diff --git a/Directory.Build.targets b/Directory.Build.targets
index 181a4b228..ad9b19cce 100644
--- a/Directory.Build.targets
+++ b/Directory.Build.targets
@@ -43,5 +43,10 @@
+
+
+
+
+
diff --git a/LSP.sln b/LSP.sln
index 0543be637..3cc10f0f9 100644
--- a/LSP.sln
+++ b/LSP.sln
@@ -44,11 +44,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Protocol", "src\Protocol\Pr
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Server", "src\Server\Server.csproj", "{E540868F-438E-4F7F-BBB7-010D6CB18A57}"
EndProject
-Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{7E4A7675-45F3-4636-88AC-2C158C1A140D}"
- ProjectSection(SolutionItems) = preProject
- cake.config = cake.config
- EndProjectSection
-EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Dap.Protocol", "src\Dap.Protocol\Dap.Protocol.csproj", "{F2C9D555-118E-442B-A953-9A7B58A53F33}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Dap.Server", "src\Dap.Server\Dap.Server.csproj", "{E1A9123B-A236-4240-8C82-A61BD85C3BF4}"
@@ -73,6 +68,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Testing", "src\Testing\Test
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "JsonRpc.Testing", "src\JsonRpc.Testing\JsonRpc.Testing.csproj", "{202BA1AB-25DA-44ED-B962-FD82FCC74543}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "JsonRpc.Generators", "src\JsonRpc.Generators\JsonRpc.Generators.csproj", "{DE259174-73DC-4532-B641-AD218971EE29}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Generation.Tests", "test\Generation.Tests\Generation.Tests.csproj", "{671FFF78-BDD2-4389-B29C-BFD183DA9120}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -293,6 +292,30 @@ Global
{202BA1AB-25DA-44ED-B962-FD82FCC74543}.Release|x64.Build.0 = Release|Any CPU
{202BA1AB-25DA-44ED-B962-FD82FCC74543}.Release|x86.ActiveCfg = Release|Any CPU
{202BA1AB-25DA-44ED-B962-FD82FCC74543}.Release|x86.Build.0 = Release|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Debug|x64.Build.0 = Debug|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Debug|x86.Build.0 = Debug|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Release|Any CPU.Build.0 = Release|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Release|x64.ActiveCfg = Release|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Release|x64.Build.0 = Release|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Release|x86.ActiveCfg = Release|Any CPU
+ {DE259174-73DC-4532-B641-AD218971EE29}.Release|x86.Build.0 = Release|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|x64.Build.0 = Debug|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|x86.Build.0 = Debug|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|Any CPU.Build.0 = Release|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|x64.ActiveCfg = Release|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|x64.Build.0 = Release|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|x86.ActiveCfg = Release|Any CPU
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -316,6 +339,8 @@ Global
{91919C54-3638-4A3C-963A-327D78368EE3} = {D764E024-3D3F-4112-B932-2DB722A1BACC}
{A1EC39EE-AA1F-4EC9-9939-28C3532585C9} = {D764E024-3D3F-4112-B932-2DB722A1BACC}
{202BA1AB-25DA-44ED-B962-FD82FCC74543} = {D764E024-3D3F-4112-B932-2DB722A1BACC}
+ {DE259174-73DC-4532-B641-AD218971EE29} = {D764E024-3D3F-4112-B932-2DB722A1BACC}
+ {671FFF78-BDD2-4389-B29C-BFD183DA9120} = {2F323ED5-EBF8-45E1-B9D3-C014561B3DDA}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {D38DD0EC-D095-4BCD-B8AF-2D788AF3B9AE}
diff --git a/src/Dap.Protocol/Dap.Protocol.csproj b/src/Dap.Protocol/Dap.Protocol.csproj
index 9cb4da29c..d25616d5f 100644
--- a/src/Dap.Protocol/Dap.Protocol.csproj
+++ b/src/Dap.Protocol/Dap.Protocol.csproj
@@ -10,6 +10,8 @@
+
+
<_Parameter1>OmniSharp.Extensions.LanguageServer, PublicKey=0024000004800000940000000602000000240000525341310004000001000100391db875e68eb4bfef49ce14313b9e13f2cd3cc89eb273bbe6c11a55044c7d4f566cf092e1c77ef9e7c75b1496ae7f95d925938f5a01793dd8d9f99ae0a7595779b71b971287d7d7b5960d052078d14f5ce1a85ea5c9fb2f59ac735ff7bc215cab469b7c3486006860bad6f4c3b5204ea2f28dd4e1d05e2cca462cfd593b9f9f
diff --git a/src/JsonRpc.Generation/JsonRpc.Generation.csproj b/src/JsonRpc.Generation/JsonRpc.Generation.csproj
new file mode 100644
index 000000000..c053a21df
--- /dev/null
+++ b/src/JsonRpc.Generation/JsonRpc.Generation.csproj
@@ -0,0 +1,15 @@
+
+
+
+
+ netstandard1.0
+ OmniSharp.Extensions.JsonRpc.Generation
+ OmniSharp.Extensions.JsonRpc.Generation
+
+
+
+
+
+
+
+
diff --git a/src/JsonRpc.Generators/GenerateHandlerMethodsGenerator.cs b/src/JsonRpc.Generators/GenerateHandlerMethodsGenerator.cs
new file mode 100644
index 000000000..f186184fd
--- /dev/null
+++ b/src/JsonRpc.Generators/GenerateHandlerMethodsGenerator.cs
@@ -0,0 +1,457 @@
+using System;
+using System.Buffers;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using CodeGeneration.Roslyn;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+using static OmniSharp.Extensions.JsonRpc.Generators.Helpers;
+
+namespace OmniSharp.Extensions.JsonRpc.Generators
+{
+ public class GenerateHandlerMethodsGenerator : IRichCodeGenerator
+ {
+ private readonly AttributeData _attributeData;
+
+ public GenerateHandlerMethodsGenerator(AttributeData attributeData)
+ {
+ _attributeData = attributeData;
+ }
+
+ public Task GenerateRichAsync(TransformationContext context, IProgress progress, CancellationToken cancellationToken)
+ {
+ if (!(context.ProcessingNode is InterfaceDeclarationSyntax handlerInterface))
+ {
+ return Task.FromResult(new RichGenerationResult());
+ }
+
+ var methods = new List();
+ var additionalUsings = new HashSet() {
+ "System",
+ "System.Collections.Generic",
+ "System.Threading",
+ "System.Threading.Tasks",
+ "MediatR",
+ "Microsoft.Extensions.DependencyInjection",
+ };
+ var symbol = context.SemanticModel.GetDeclaredSymbol(handlerInterface);
+
+ var className = GetExtensionClassName(symbol);
+
+ foreach (var registry in GetRegistries(_attributeData, handlerInterface, symbol, context, progress, additionalUsings))
+ {
+ if (IsNotification(symbol))
+ {
+ var requestType = GetRequestType(symbol);
+ methods.AddRange(HandleNotifications(handlerInterface, symbol, requestType, registry, additionalUsings));
+ }
+ else if (IsRequest(symbol))
+ {
+ var requestType = GetRequestType(symbol);
+ var responseType = GetResponseType(symbol);
+ methods.AddRange(HandleRequest(handlerInterface, symbol, requestType, responseType, registry, additionalUsings));
+ }
+ }
+
+ var existingUsings = context.CompilationUnitUsings
+ .Join(additionalUsings, z => z.Name.ToFullString(), z => z, (a, b) => b)
+ ;
+
+ var newUsings = additionalUsings
+ .Except(existingUsings)
+ .Select(z => UsingDirective(IdentifierName(z)))
+ ;
+
+ var attributes = List(new[] {
+ AttributeList(SeparatedList(new[] {
+ Attribute(ParseName("System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute")),
+ Attribute(ParseName("System.Runtime.CompilerServices.CompilerGeneratedAttribute")),
+ }))
+ });
+
+ return Task.FromResult(new RichGenerationResult() {
+ Usings = List(newUsings),
+ Members = List(new[] {
+ NamespaceDeclaration(ParseName(symbol.ContainingNamespace.ToDisplayString()))
+
+ .WithMembers(List(new MemberDeclarationSyntax[] {
+ ClassDeclaration(className)
+ .WithAttributeLists(attributes)
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword),
+ Token(SyntaxKind.PartialKeyword)
+ ))
+ .WithMembers(List(methods))
+ .NormalizeWhitespace()
+ }))
+ })
+ });
+ }
+
+ IEnumerable HandleNotifications(
+ InterfaceDeclarationSyntax handlerInterface,
+ INamedTypeSymbol interfaceType,
+ INamedTypeSymbol requestType,
+ NameSyntax registryType,
+ HashSet additionalUsings)
+ {
+ var methodName = GetOnMethodName(interfaceType, _attributeData);
+
+ var parameters = ParameterList(SeparatedList(new[] {
+ Parameter(Identifier("registry"))
+ .WithType(registryType)
+ .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword)))
+ }));
+
+ var capability = GetCapability(interfaceType);
+ var registrationOptions = GetRegistrationOptions(interfaceType);
+ if (capability != null) additionalUsings.Add(capability.ContainingNamespace.ToDisplayString());
+ if (registrationOptions != null) additionalUsings.Add(registrationOptions.ContainingNamespace.ToDisplayString());
+
+ if (registrationOptions == null)
+ {
+ var method = MethodDeclaration(registryType, methodName)
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword))
+ )
+ .WithExpressionBody(GetNotificationHandlerExpression(GetMethodName(handlerInterface)))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
+
+ MemberDeclarationSyntax MakeAction(TypeSyntax syntax)
+ {
+ return method
+ .WithParameterList(parameters.AddParameters(Parameter(Identifier("handler"))
+ .WithType(syntax)))
+ .NormalizeWhitespace();
+ }
+
+ yield return MakeAction(CreateAction(false, requestType));
+ yield return MakeAction(CreateAsyncAction(false, requestType));
+ yield return MakeAction(CreateAction(true, requestType));
+ yield return MakeAction(CreateAsyncAction(true, requestType));
+ if (capability != null)
+ {
+ yield return MakeAction(CreateAction(requestType, capability));
+ yield return MakeAction(CreateAsyncAction(requestType, capability));
+ yield return MakeAction(CreateAction(requestType, capability));
+ yield return MakeAction(CreateAsyncAction(requestType, capability));
+ }
+ }
+ else
+ {
+ var method = MethodDeclaration(registryType, methodName)
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword))
+ )
+ .WithBody(GetNotificationRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, registrationOptions));
+
+ var registrationParameter = Parameter(Identifier("registrationOptions"))
+ .WithType(IdentifierName(registrationOptions.Name));
+
+ MemberDeclarationSyntax MakeAction(TypeSyntax syntax)
+ {
+ return method
+ .WithParameterList(parameters.WithParameters(SeparatedList(parameters.Parameters.Concat(
+ new[] {Parameter(Identifier("handler")).WithType(syntax), registrationParameter}))))
+ .NormalizeWhitespace();
+ }
+
+ yield return MakeAction(CreateAction(false, requestType));
+ yield return MakeAction(CreateAsyncAction(false, requestType));
+ yield return MakeAction(CreateAction(true, requestType));
+ yield return MakeAction(CreateAsyncAction(true, requestType));
+ if (capability != null)
+ {
+ method = method.WithBody(
+ GetNotificationRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, registrationOptions, capability));
+ yield return MakeAction(CreateAction(requestType, capability));
+ yield return MakeAction(CreateAsyncAction(requestType, capability));
+ }
+ }
+ }
+
+ IEnumerable HandleRequest(
+ InterfaceDeclarationSyntax handlerInterface,
+ INamedTypeSymbol interfaceType,
+ INamedTypeSymbol requestType,
+ INamedTypeSymbol responseType,
+ NameSyntax registryType,
+ HashSet additionalUsings)
+ {
+ var methodName = GetOnMethodName(interfaceType, _attributeData);
+
+ var capability = GetCapability(interfaceType);
+ var registrationOptions = GetRegistrationOptions(interfaceType);
+ var partialItems = GetPartialItems(requestType);
+ var partialItem = GetPartialItem(requestType);
+ if (capability != null) additionalUsings.Add(capability.ContainingNamespace.ToDisplayString());
+ if (registrationOptions != null) additionalUsings.Add(registrationOptions.ContainingNamespace.ToDisplayString());
+ if (partialItems != null) additionalUsings.Add(partialItems.ContainingNamespace.ToDisplayString());
+ if (partialItem != null) additionalUsings.Add(partialItem.ContainingNamespace.ToDisplayString());
+
+ var parameters = ParameterList(SeparatedList(new[] {
+ Parameter(Identifier("registry"))
+ .WithType(registryType)
+ .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword)))
+ }));
+
+
+ if (registrationOptions == null)
+ {
+ var method = MethodDeclaration(registryType, methodName)
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword))
+ )
+ .WithExpressionBody(GetRequestHandlerExpression(GetMethodName(handlerInterface)))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
+
+ MemberDeclarationSyntax MakeAction(TypeSyntax syntax)
+ {
+ return method
+ .WithParameterList(parameters.AddParameters(Parameter(Identifier("handler"))
+ .WithType(syntax)))
+ .NormalizeWhitespace();
+ }
+
+ yield return MakeAction(CreateAsyncFunc(responseType, false, requestType));
+ yield return MakeAction(CreateAsyncFunc(responseType, true, requestType));
+ if (partialItems != null)
+ {
+ var partialTypeSyntax = ResolveTypeName(partialItems);
+ var partialItemsSyntax = GenericName("IEnumerable").WithTypeArgumentList(TypeArgumentList(SeparatedList(new[] {partialTypeSyntax})));
+
+ method = method.WithExpressionBody(GetPartialResultsHandlerExpression(GetMethodName(handlerInterface), requestType, partialTypeSyntax, responseType));
+
+ yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, true));
+ yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, false));
+ if (capability != null)
+ {
+ method = method.WithExpressionBody(GetPartialResultsCapabilityHandlerExpression(GetMethodName(handlerInterface), requestType, responseType,
+ partialTypeSyntax, capability));
+ yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, capability));
+ }
+ }
+
+ if (partialItem != null)
+ {
+ var partialTypeSyntax = ResolveTypeName(partialItem);
+
+ method = method.WithExpressionBody(GetPartialResultHandlerExpression(GetMethodName(handlerInterface), requestType, responseType));
+
+ yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, true));
+ yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, false));
+ if (capability != null)
+ {
+ method = method.WithExpressionBody(GetPartialResultCapabilityHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, capability));
+ yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, capability));
+ }
+ }
+
+ if (capability != null)
+ {
+ method = method.WithExpressionBody(
+ GetRequestCapabilityHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, capability));
+ yield return MakeAction(CreateAsyncFunc(responseType, requestType, capability));
+ }
+ }
+ else
+ {
+ var method = MethodDeclaration(registryType, methodName)
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword))
+ )
+ .WithBody(GetRequestRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, registrationOptions));
+ if (responseType.Name == "Unit")
+ {
+ method = method.WithBody(GetVoidRequestRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, registrationOptions));
+ }
+
+ var registrationParameter = Parameter(Identifier("registrationOptions"))
+ .WithType(IdentifierName(registrationOptions.Name));
+
+ MemberDeclarationSyntax MakeAction(TypeSyntax syntax)
+ {
+ return method
+ .WithParameterList(parameters.WithParameters(SeparatedList(parameters.Parameters.Concat(
+ new[] {Parameter(Identifier("handler")).WithType(syntax), registrationParameter}))))
+ .NormalizeWhitespace();
+ }
+
+ yield return MakeAction(CreateAsyncFunc(responseType, false, requestType));
+ yield return MakeAction(CreateAsyncFunc(responseType, true, requestType));
+
+ if (partialItems != null)
+ {
+ var partialTypeSyntax = ResolveTypeName(partialItems);
+ var partialItemsSyntax = GenericName("IEnumerable").WithTypeArgumentList(TypeArgumentList(SeparatedList(new[] {partialTypeSyntax})));
+
+ method = method.WithBody(GetPartialResultsRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, partialTypeSyntax,
+ registrationOptions));
+
+ yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, true));
+ yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, false));
+ if (capability != null)
+ {
+ method = method.WithBody(GetPartialResultsRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, partialTypeSyntax,
+ registrationOptions, capability));
+ yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, capability));
+ }
+ }
+
+ if (partialItem != null)
+ {
+ var partialTypeSyntax = ResolveTypeName(partialItem);
+
+ method = method.WithBody(GetPartialResultRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, registrationOptions));
+
+ yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, true));
+ yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, false));
+ if (capability != null)
+ {
+ method = method.WithBody(GetPartialResultRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, registrationOptions,
+ capability));
+ yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, capability));
+ }
+ }
+
+ if (capability != null)
+ {
+ method = method.WithBody(
+ GetRequestRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, registrationOptions, capability));
+ if (responseType.Name == "Unit")
+ {
+ method = method.WithBody(GetVoidRequestRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, registrationOptions, capability));
+ }
+
+ yield return MakeAction(CreateAsyncFunc(responseType, requestType, capability));
+ }
+ }
+ }
+
+ static IEnumerable GetRegistries(
+ AttributeData attributeData,
+ InterfaceDeclarationSyntax interfaceSyntax,
+ INamedTypeSymbol interfaceType,
+ TransformationContext context,
+ IProgress progress,
+ HashSet additionalUsings)
+ {
+ if (attributeData.ConstructorArguments[0].Values.Length > 0)
+ {
+ return attributeData.ConstructorArguments[0].Values.Select(z => z.Value).OfType()
+ .Select(ResolveTypeName);
+ }
+
+ if (interfaceType.ContainingNamespace.ToDisplayString().StartsWith("OmniSharp.Extensions.LanguageServer.Protocol"))
+ {
+ var attribute = interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute");
+ if (attribute.ConstructorArguments.Length < 2)
+ {
+ progress.Report(Diagnostic.Create(GeneratorDiagnostics.MissingDirection, interfaceSyntax.Identifier.GetLocation()));
+ return Enumerable.Empty();
+ }
+
+ var direction = (int) interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute").ConstructorArguments[1].Value;
+
+ /*
+ Unspecified = 0b0000,
+ ServerToClient = 0b0001,
+ ClientToServer = 0b0010,
+ Bidirectional = 0b0011
+ */
+ var maskedDirection = (0b0011 & direction);
+
+
+ if (maskedDirection == 1)
+ {
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client");
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities");
+ return new[] {LanguageProtocolServerToClient};
+ }
+
+ if (maskedDirection == 2)
+ {
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Server");
+ return new[] {LanguageProtocolClientToServer};
+ }
+
+ if (maskedDirection == 3)
+ {
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client");
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities");
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Server");
+ return new[] {LanguageProtocolClientToServer, LanguageProtocolServerToClient};
+ }
+ }
+
+ if (interfaceType.ContainingNamespace.ToDisplayString().StartsWith("OmniSharp.Extensions.DebugAdapter.Protocol"))
+ {
+ var attribute = interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute");
+ if (attribute.ConstructorArguments.Length < 2)
+ {
+ progress.Report(Diagnostic.Create(GeneratorDiagnostics.MissingDirection, interfaceSyntax.Identifier.GetLocation()));
+ return Enumerable.Empty();
+ }
+
+ var direction = (int) interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute").ConstructorArguments[1].Value;
+
+ /*
+ Unspecified = 0b0000,
+ ServerToClient = 0b0001,
+ ClientToServer = 0b0010,
+ Bidirectional = 0b0011
+ */
+ var maskedDirection = (0b0011 & direction);
+ additionalUsings.Add("OmniSharp.Extensions.DebugAdapter.Protocol");
+
+ if (maskedDirection == 1)
+ {
+ return new[] {DebugProtocolServerToClient};
+ }
+
+ if (maskedDirection == 2)
+ {
+ return new[] {DebugProtocolClientToServer};
+ }
+
+ if (maskedDirection == 3)
+ {
+ return new[] {DebugProtocolClientToServer, DebugProtocolServerToClient};
+ }
+ }
+
+ throw new NotImplementedException("Add inference logic here " + interfaceSyntax.Identifier.ToFullString());
+ }
+
+ private static NameSyntax LanguageProtocolServerToClient { get; } =
+ IdentifierName("ILanguageClientRegistry");
+
+ private static NameSyntax LanguageProtocolClientToServer { get; } =
+ IdentifierName("ILanguageServerRegistry");
+
+ private static NameSyntax DebugProtocolServerToClient { get; } =
+ IdentifierName("IDebugAdapterClientRegistry");
+
+ private static NameSyntax DebugProtocolClientToServer { get; } =
+ IdentifierName("IDebugAdapterServerRegistry");
+
+ public Task> GenerateAsync(TransformationContext context, IProgress progress, CancellationToken cancellationToken) =>
+ throw new NotImplementedException();
+ }
+
+ /*
+ IJsonRpcNotificationHandler
+ IJsonRpcRequestHandler
+ IJsonRpcRequestHandler
+ */
+}
diff --git a/src/JsonRpc.Generators/GenerateRequestMethodsGenerator.cs b/src/JsonRpc.Generators/GenerateRequestMethodsGenerator.cs
new file mode 100644
index 000000000..9ba5ead07
--- /dev/null
+++ b/src/JsonRpc.Generators/GenerateRequestMethodsGenerator.cs
@@ -0,0 +1,338 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using CodeGeneration.Roslyn;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+using static OmniSharp.Extensions.JsonRpc.Generators.Helpers;
+
+namespace OmniSharp.Extensions.JsonRpc.Generators
+{
+ public class GenerateRequestMethodsGenerator : IRichCodeGenerator
+ {
+ private readonly AttributeData _attributeData;
+
+ public GenerateRequestMethodsGenerator(AttributeData attributeData)
+ {
+ _attributeData = attributeData;
+ }
+
+ public Task> GenerateAsync(TransformationContext context, IProgress progress, CancellationToken cancellationToken)
+ {
+ throw new NotImplementedException();
+ }
+
+ public Task GenerateRichAsync(TransformationContext context, IProgress progress, CancellationToken cancellationToken)
+ {
+ if (!(context.ProcessingNode is InterfaceDeclarationSyntax handlerInterface))
+ {
+ return Task.FromResult(new RichGenerationResult());
+ }
+
+ var methods = new List();
+ var additionalUsings = new HashSet() {
+ "System",
+ "System.Collections.Generic",
+ "System.Threading",
+ "System.Threading.Tasks",
+ "MediatR",
+ "Microsoft.Extensions.DependencyInjection",
+ };
+ var symbol = context.SemanticModel.GetDeclaredSymbol(handlerInterface);
+
+ var className = GetExtensionClassName(symbol);
+
+ var registries = GetProxies(_attributeData, handlerInterface, symbol, context, progress, additionalUsings);
+
+ if (_attributeData.ConstructorArguments[0].Values.Length == 0 && !symbol.ContainingNamespace.ToDisplayString().StartsWith("OmniSharp.Extensions.DebugAdapter.Protocol"))
+ {
+ progress.Report(Diagnostic.Create(GeneratorDiagnostics.NoResponseRouterProvided, handlerInterface.Identifier.GetLocation(), symbol.Name,
+ string.Join(", ", registries.Select(z => z.ToFullString()))));
+ }
+
+ foreach (var registry in registries)
+ {
+ if (IsNotification(symbol))
+ {
+ var requestType = GetRequestType(symbol);
+ methods.AddRange(HandleNotifications(handlerInterface, symbol, requestType, registry, additionalUsings));
+ }
+
+ if (IsRequest(symbol))
+ {
+ var requestType = GetRequestType(symbol);
+ var responseType = GetResponseType(symbol);
+ methods.AddRange(HandleRequests(handlerInterface, symbol, requestType, responseType, registry, additionalUsings));
+ }
+ }
+
+
+ var attributes = List(new[] {
+ AttributeList(SeparatedList(new[] {
+ Attribute(ParseName("System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute")),
+ Attribute(ParseName("System.Runtime.CompilerServices.CompilerGeneratedAttribute")),
+ }))
+ });
+ if (symbol.GetAttributes().Any(z => z.AttributeClass.Name == "GenerateRequestMethodsAttribute"))
+ {
+ attributes = List();
+ }
+
+ var existingUsings = context.CompilationUnitUsings
+ .Join(additionalUsings, z => z.Name.ToFullString(), z => z, (a, b) => b)
+ ;
+
+ var newUsings = additionalUsings
+ .Except(existingUsings)
+ .Select(z => UsingDirective(IdentifierName(z)))
+ ;
+ return Task.FromResult(new RichGenerationResult() {
+ Usings = List(newUsings),
+ Members = List(new[] {
+ NamespaceDeclaration(ParseName(symbol.ContainingNamespace.ToDisplayString()))
+ .WithMembers(List(new MemberDeclarationSyntax[] {
+ ClassDeclaration(className)
+ .WithAttributeLists(attributes)
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword),
+ Token(SyntaxKind.PartialKeyword)
+ ))
+ .WithMembers(List(methods))
+ .NormalizeWhitespace()
+ }))
+ })
+ });
+ }
+
+ public static IEnumerable GetProxies(
+ AttributeData attributeData,
+ InterfaceDeclarationSyntax interfaceSyntax,
+ INamedTypeSymbol interfaceType,
+ TransformationContext context,
+ IProgress progress,
+ HashSet additionalUsings)
+ {
+ if (attributeData.ConstructorArguments[0].Values.Length > 0)
+ {
+ return attributeData.ConstructorArguments[0].Values.Select(z => z.Value).OfType()
+ .Select(ResolveTypeName);
+ }
+
+ if (interfaceType.ContainingNamespace.ToDisplayString().StartsWith("OmniSharp.Extensions.LanguageServer.Protocol"))
+ {
+ var attribute = interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute");
+ if (attribute.ConstructorArguments.Length < 2)
+ {
+ progress.Report(Diagnostic.Create(GeneratorDiagnostics.MissingDirection, interfaceSyntax.Identifier.GetLocation()));
+ return Enumerable.Empty();
+ }
+
+ var direction = (int) interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute").ConstructorArguments[1].Value;
+
+ /*
+ Unspecified = 0b0000,
+ ServerToClient = 0b0001,
+ ClientToServer = 0b0010,
+ Bidirectional = 0b0011
+ */
+ var maskedDirection = (0b0011 & direction);
+
+ if (maskedDirection == 1)
+ {
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Server");
+ return new[] {LanguageProtocolServerToClient};
+ }
+
+ if (maskedDirection == 2)
+ {
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client");
+ return new[] {LanguageProtocolClientToServer};
+ }
+
+ if (maskedDirection == 3)
+ {
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Server");
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client");
+ return new[] {LanguageProtocolClientToServer, LanguageProtocolServerToClient};
+ }
+ }
+
+ if (interfaceType.ContainingNamespace.ToDisplayString().StartsWith("OmniSharp.Extensions.DebugAdapter.Protocol"))
+ {
+ var attribute = interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute");
+ if (attribute.ConstructorArguments.Length < 2)
+ {
+ progress.Report(Diagnostic.Create(GeneratorDiagnostics.MissingDirection, interfaceSyntax.Identifier.GetLocation()));
+ return Enumerable.Empty();
+ }
+
+ var direction = (int) interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute").ConstructorArguments[1].Value;
+
+ /*
+ Unspecified = 0b0000,
+ ServerToClient = 0b0001,
+ ClientToServer = 0b0010,
+ Bidirectional = 0b0011
+ */
+ var maskedDirection = (0b0011 & direction);
+ additionalUsings.Add("OmniSharp.Extensions.DebugAdapter.Protocol");
+
+ if (maskedDirection == 1)
+ {
+ return new[] {DebugProtocolServerToClient};
+ }
+
+ if (maskedDirection == 2)
+ {
+ return new[] {DebugProtocolClientToServer};
+ }
+
+ if (maskedDirection == 3)
+ {
+ return new[] {DebugProtocolClientToServer, DebugProtocolServerToClient};
+ }
+ }
+
+ throw new NotImplementedException("Add inference logic here " + interfaceSyntax.Identifier.ToFullString());
+ }
+
+ private static NameSyntax LanguageProtocolServerToClient { get; } =
+ ParseName("ILanguageServer");
+
+ private static NameSyntax LanguageProtocolClientToServer { get; } =
+ ParseName("ILanguageClient");
+
+ private static NameSyntax DebugProtocolServerToClient { get; } =
+ ParseName("IDebugAdapterServer");
+
+ private static NameSyntax DebugProtocolClientToServer { get; } =
+ ParseName("IDebugAdapterClient");
+
+ IEnumerable HandleNotifications(
+ InterfaceDeclarationSyntax handlerInterface,
+ INamedTypeSymbol interfaceType,
+ INamedTypeSymbol requestType,
+ NameSyntax registryType,
+ HashSet additionalUsings)
+ {
+ var methodName = GetSendMethodName(interfaceType, _attributeData);
+ var method = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), methodName)
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword))
+ )
+ .WithExpressionBody(GetNotificationInvokeExpression())
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
+
+ yield return method
+ .WithParameterList(
+ ParameterList(SeparatedList(new[] {
+ Parameter(Identifier("mediator"))
+ .WithType(registryType)
+ .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword))),
+ Parameter(Identifier("@params"))
+ .WithType(IdentifierName(requestType.Name))
+ })))
+ .NormalizeWhitespace();
+ }
+
+ IEnumerable HandleRequests(
+ InterfaceDeclarationSyntax handlerInterface,
+ INamedTypeSymbol interfaceType,
+ INamedTypeSymbol requestType,
+ INamedTypeSymbol responseType,
+ NameSyntax registryType,
+ HashSet additionalUsings)
+ {
+ var methodName = GetSendMethodName(interfaceType, _attributeData);
+ var parameterList = ParameterList(SeparatedList(new[] {
+ Parameter(Identifier("mediator"))
+ .WithType(registryType)
+ .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword))),
+ Parameter(Identifier("@params"))
+ .WithType(IdentifierName(requestType.Name)),
+ Parameter(Identifier("cancellationToken"))
+ .WithType(IdentifierName("CancellationToken"))
+ .WithDefault(EqualsValueClause(
+ LiteralExpression(SyntaxKind.DefaultLiteralExpression, Token(SyntaxKind.DefaultKeyword)))
+ )
+ }));
+ var partialItem = GetPartialItem(requestType);
+ if (partialItem != null)
+ {
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Progress");
+ var partialTypeSyntax = ResolveTypeName(partialItem);
+ yield return MethodDeclaration(
+ GenericName(
+ Identifier("IRequestProgressObservable"))
+ .WithTypeArgumentList(
+ TypeArgumentList(
+ SeparatedList(
+ new TypeSyntax[] {
+ partialTypeSyntax,
+ ResolveTypeName(responseType)
+ }))),
+ Identifier(methodName)
+ )
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword))
+ )
+ .WithParameterList(parameterList)
+ .WithExpressionBody(GetPartialInvokeExpression(ResolveTypeName(responseType)))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
+ .NormalizeWhitespace();
+ yield break;
+ }
+
+ var partialItems = GetPartialItems(requestType);
+ if (partialItems != null)
+ {
+ additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Progress");
+ var partialTypeSyntax = ResolveTypeName(partialItems);
+ var partialItemsSyntax = GenericName("IEnumerable").WithTypeArgumentList(TypeArgumentList(SeparatedList(new[] {partialTypeSyntax})));
+ yield return MethodDeclaration(
+ GenericName(
+ Identifier("IRequestProgressObservable"))
+ .WithTypeArgumentList(
+ TypeArgumentList(
+ SeparatedList(
+ new TypeSyntax[] {
+ partialItemsSyntax,
+ ResolveTypeName(responseType)
+ }))),
+ Identifier(methodName)
+ )
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword))
+ )
+ .WithParameterList(parameterList)
+ .WithExpressionBody(GetPartialInvokeExpression(ResolveTypeName(responseType)))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
+ .NormalizeWhitespace();
+ ;
+ yield break;
+ }
+
+
+ var responseSyntax = responseType.Name == "Unit"
+ ? IdentifierName("Task") as NameSyntax
+ : GenericName("Task").WithTypeArgumentList(TypeArgumentList(SeparatedList(new[] {ResolveTypeName(responseType)})));
+ yield return MethodDeclaration(responseSyntax, methodName)
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword))
+ )
+ .WithParameterList(parameterList)
+ .WithExpressionBody(GetRequestInvokeExpression())
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
+ .NormalizeWhitespace();
+ }
+ }
+}
diff --git a/src/JsonRpc.Generators/GeneratorDiagnostics.cs b/src/JsonRpc.Generators/GeneratorDiagnostics.cs
new file mode 100644
index 000000000..e00ef03e1
--- /dev/null
+++ b/src/JsonRpc.Generators/GeneratorDiagnostics.cs
@@ -0,0 +1,16 @@
+using Microsoft.CodeAnalysis;
+
+namespace OmniSharp.Extensions.JsonRpc.Generators
+{
+ static class GeneratorDiagnostics
+ {
+ public static DiagnosticDescriptor MissingDirection { get; } = new DiagnosticDescriptor("LSP1000", "Missing Direction",
+ "No direction defined for Language Server Protocol Handler", "JsonRPC", DiagnosticSeverity.Warning, true);
+
+ public static DiagnosticDescriptor NoHandlerRegistryProvided { get; } = new DiagnosticDescriptor("JRPC1000", "No Handler Registry Provided",
+ "No Handler Registry Provided for handler {0}.", "JsonRPC", DiagnosticSeverity.Warning, true);
+
+ public static DiagnosticDescriptor NoResponseRouterProvided { get; } = new DiagnosticDescriptor("JRPC1001", "No Response Router Provided",
+ "No Response Router Provided for handler {0}, defaulting to {1}.", "JsonRPC", DiagnosticSeverity.Warning, true);
+ }
+}
diff --git a/src/JsonRpc.Generators/Helpers.cs b/src/JsonRpc.Generators/Helpers.cs
new file mode 100644
index 000000000..401cb808f
--- /dev/null
+++ b/src/JsonRpc.Generators/Helpers.cs
@@ -0,0 +1,868 @@
+#nullable enable
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text.RegularExpressions;
+using System.Threading;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+
+namespace OmniSharp.Extensions.JsonRpc.Generators
+{
+ static class Helpers
+ {
+ public static bool IsNotification(INamedTypeSymbol symbol)
+ {
+ return symbol.AllInterfaces.Any(z => z.Name == "IJsonRpcNotificationHandler");
+ }
+
+ public static bool IsRequest(INamedTypeSymbol symbol)
+ {
+ return symbol.AllInterfaces.Any(z => z.Name == "IJsonRpcRequestHandler");
+ }
+
+ public static ExpressionSyntax GetMethodName(InterfaceDeclarationSyntax interfaceSyntax)
+ {
+ var methodAttribute = interfaceSyntax.AttributeLists
+ .SelectMany(z => z.Attributes)
+ .First(z => z.Name.ToString() == "MethodAttribute" || z.Name.ToString() == "Method");
+
+ return methodAttribute.ArgumentList.Arguments[0].Expression;
+ }
+
+ public static INamedTypeSymbol GetRequestType(INamedTypeSymbol symbol)
+ {
+ var handlerInterface = symbol.AllInterfaces.First(z => z.Name == "IRequestHandler" && z.TypeArguments.Length == 2);
+ return handlerInterface.TypeArguments[0] as INamedTypeSymbol;
+ }
+
+ public static INamedTypeSymbol GetResponseType(INamedTypeSymbol symbol)
+ {
+ var handlerInterface = symbol.AllInterfaces.First(z => z.Name == "IRequestHandler" && z.TypeArguments.Length == 2);
+ return handlerInterface.TypeArguments[1] as INamedTypeSymbol;
+ }
+
+ public static INamedTypeSymbol? GetCapability(INamedTypeSymbol symbol)
+ {
+ var handlerInterface = symbol.AllInterfaces
+ .FirstOrDefault(z => z.Name == "ICapability" && z.TypeArguments.Length == 1);
+ return handlerInterface?.TypeArguments[0] as INamedTypeSymbol;
+ }
+
+ public static INamedTypeSymbol? GetRegistrationOptions(INamedTypeSymbol symbol)
+ {
+ var handlerInterface = symbol.AllInterfaces
+ .FirstOrDefault(z => z.Name == "IRegistration" && z.TypeArguments.Length == 1);
+ return handlerInterface?.TypeArguments[0] as INamedTypeSymbol;
+ }
+
+ public static INamedTypeSymbol? GetPartialItems(INamedTypeSymbol symbol)
+ {
+ var handlerInterface = symbol.AllInterfaces
+ .FirstOrDefault(z => z.Name == "IPartialItems" && z.TypeArguments.Length == 1);
+ return handlerInterface?.TypeArguments[0] as INamedTypeSymbol;
+ }
+
+ public static INamedTypeSymbol? GetPartialItem(INamedTypeSymbol symbol)
+ {
+ var handlerInterface = symbol.AllInterfaces
+ .FirstOrDefault(z => z.Name == "IPartialItem" && z.TypeArguments.Length == 1);
+ return handlerInterface?.TypeArguments[0] as INamedTypeSymbol;
+ }
+
+ public static GenericNameSyntax CreateAction(bool withCancellationToken, params ITypeSymbol[] types)
+ {
+ var typeArguments = types.Select(ResolveTypeName).ToList();
+ if (withCancellationToken)
+ {
+ typeArguments.Add(IdentifierName("CancellationToken"));
+ }
+
+ return GenericName(Identifier("Action"))
+ .WithTypeArgumentList(TypeArgumentList(SeparatedList(typeArguments)));
+ }
+
+ public static NameSyntax ResolveTypeName(ITypeSymbol symbol)
+ {
+ if (symbol is INamedTypeSymbol namedTypeSymbol)
+ {
+ if (namedTypeSymbol.IsGenericType)
+ {
+ // TODO: Fix for generic types
+ return ParseName(namedTypeSymbol.ToString());
+ }
+
+ // Assume that we're adding the correct namespaces.
+ return IdentifierName(namedTypeSymbol.Name);
+ }
+
+ return IdentifierName(symbol.Name);
+ }
+
+ public static GenericNameSyntax CreateAction(params ITypeSymbol[] types) => CreateAction(true, types);
+
+ public static GenericNameSyntax CreateAsyncAction(params ITypeSymbol[] types) => CreateAsyncFunc(null, true, types);
+
+ public static GenericNameSyntax CreateAsyncAction(bool withCancellationToken, params ITypeSymbol[] types) => CreateAsyncFunc(null, withCancellationToken, types);
+
+ public static GenericNameSyntax CreateAsyncFunc(ITypeSymbol? responseType, params ITypeSymbol[] types) => CreateAsyncFunc(responseType, true, types);
+
+ public static GenericNameSyntax CreateAsyncFunc(ITypeSymbol? responseType, bool withCancellationToken, params ITypeSymbol[] types)
+ {
+ var typeArguments = types.Select(ResolveTypeName).ToList();
+ if (withCancellationToken)
+ {
+ typeArguments.Add(IdentifierName("CancellationToken"));
+ }
+
+ if (responseType == null || responseType.Name == "Unit")
+ {
+ typeArguments.Add(IdentifierName("Task"));
+ }
+ else
+ {
+ typeArguments.Add(GenericName(Identifier("Task"), TypeArgumentList(SeparatedList(new TypeSyntax[] {
+ ResolveTypeName(responseType)
+ }))));
+ }
+
+ return GenericName(Identifier("Func"))
+ .WithTypeArgumentList(TypeArgumentList(SeparatedList(typeArguments)));
+ }
+
+ public static GenericNameSyntax CreatePartialAction(ITypeSymbol requestType, NameSyntax partialType, bool withCancellationToken, params ITypeSymbol[] types)
+ {
+ var typeArguments = new List() {
+ ResolveTypeName(requestType),
+ GenericName("IObserver").WithTypeArgumentList(TypeArgumentList(SeparatedList(new TypeSyntax[] {partialType}))),
+ };
+ typeArguments.AddRange(types.Select(ResolveTypeName));
+ if (withCancellationToken)
+ {
+ typeArguments.Add(IdentifierName("CancellationToken"));
+ }
+
+ return GenericName(Identifier("Action"))
+ .WithTypeArgumentList(TypeArgumentList(SeparatedList(typeArguments)));
+ }
+
+ public static GenericNameSyntax CreatePartialAction(ITypeSymbol requestType, NameSyntax partialType, params ITypeSymbol[] types) =>
+ CreatePartialAction(requestType, partialType, true, types);
+
+ private static ExpressionStatementSyntax EnsureRegistrationOptionsIsSet(NameSyntax registrationOptionsName, TypeSyntax registrationOptionsType)
+ {
+ return ExpressionStatement(AssignmentExpression(
+ SyntaxKind.CoalesceAssignmentExpression,
+ registrationOptionsName,
+ ObjectCreationExpression(registrationOptionsType)
+ .WithArgumentList(ArgumentList())));
+ }
+
+ private static InvocationExpressionSyntax AddHandler(params ArgumentSyntax[] arguments) => InvocationExpression(
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("registry"),
+ IdentifierName("AddHandler")))
+ .WithArgumentList(ArgumentList(SeparatedList(arguments)));
+
+ private static ArgumentListSyntax GetHandlerArgumentList() =>
+ ArgumentList(SeparatedList(new[] {
+ Argument(IdentifierName("handler"))
+ }));
+
+ private static ArgumentListSyntax GetRegistrationHandlerArgumentList(NameSyntax registrationOptionsName) =>
+ ArgumentList(SeparatedList(new[] {
+ Argument(IdentifierName("handler")),
+ Argument(registrationOptionsName)
+ }));
+
+ private static ArgumentListSyntax GetPartialResultArgumentList(NameSyntax responseName) =>
+ ArgumentList(
+ SeparatedList(
+ new[] {
+ Argument(IdentifierName("handler")),
+ Argument(InvocationExpression(MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("_"),
+ GenericName(Identifier("GetService"))
+ .WithTypeArgumentList(TypeArgumentList(SingletonSeparatedList(IdentifierName("IProgressManager"))))
+ ))),
+ Argument(SimpleLambdaExpression(Parameter(Identifier("values")),
+ ObjectCreationExpression(responseName)
+ .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("values")))))))
+ }));
+
+ private static ArgumentListSyntax GetPartialResultRegistrationArgumentList(NameSyntax registrationOptionsName, NameSyntax responseName) =>
+ ArgumentList(
+ SeparatedList(
+ new[] {
+ Argument(IdentifierName("handler")),
+ Argument(registrationOptionsName),
+ Argument(InvocationExpression(MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("_"),
+ GenericName(Identifier("GetService"))
+ .WithTypeArgumentList(TypeArgumentList(SingletonSeparatedList(IdentifierName("IProgressManager"))))
+ ))),
+ Argument(SimpleLambdaExpression(Parameter(Identifier("values")),
+ ObjectCreationExpression(responseName)
+ .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("values")))))))
+ }));
+
+ private static ArgumentListSyntax GetPartialItemsArgumentList(NameSyntax responseName) =>
+ ArgumentList(
+ SeparatedList(
+ new[] {
+ Argument(IdentifierName("handler")),
+ Argument(InvocationExpression(MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("_"),
+ GenericName(Identifier("GetService"))
+ .WithTypeArgumentList(TypeArgumentList(SingletonSeparatedList(IdentifierName("IProgressManager"))))
+ ))),
+ Argument(SimpleLambdaExpression(Parameter(Identifier("values")),
+ ObjectCreationExpression(responseName)
+ .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("values")))))))
+ }));
+
+ private static ArgumentListSyntax GetPartialItemsRegistrationArgumentList(NameSyntax registrationOptionsName, NameSyntax responseName) =>
+ ArgumentList(
+ SeparatedList(
+ new[] {
+ Argument(IdentifierName("handler")),
+ Argument(registrationOptionsName),
+ Argument(InvocationExpression(MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("_"),
+ GenericName(Identifier("GetService"))
+ .WithTypeArgumentList(TypeArgumentList(SingletonSeparatedList(IdentifierName("IProgressManager"))))
+ ))),
+ Argument(SimpleLambdaExpression(Parameter(Identifier("values")),
+ ObjectCreationExpression(responseName)
+ .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("values")))))))
+ }));
+
+ private static ObjectCreationExpressionSyntax CreateHandlerArgument(NameSyntax className, string innerClassName, params TypeSyntax[] genericArguments)
+ {
+ return ObjectCreationExpression(
+ QualifiedName(
+ className,
+ GenericName(innerClassName).WithTypeArgumentList(TypeArgumentList(SeparatedList(genericArguments)))
+ )
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetNotificationCapabilityHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol capability)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var capabilityName = ResolveTypeName(capability);
+ return ArrowExpressionClause(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "NotificationCapability",
+ requestName,
+ capabilityName
+ )
+ .WithArgumentList(GetHandlerArgumentList()))
+ )
+ );
+ }
+
+ public static BlockSyntax GetNotificationRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol registrationOptions)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var registrationOptionsName = ResolveTypeName(registrationOptions);
+ return Block(
+ EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName),
+ ReturnStatement(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "Notification",
+ requestName,
+ registrationOptionsName
+ )
+ .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions"))))
+ )
+ )
+ );
+ }
+
+ public static BlockSyntax GetNotificationRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol registrationOptions,
+ ITypeSymbol capability)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var registrationOptionsName = ResolveTypeName(registrationOptions);
+ var capabilityName = ResolveTypeName(capability);
+ return Block(
+ EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName),
+ ReturnStatement(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "Notification",
+ requestName,
+ capabilityName,
+ registrationOptionsName
+ )
+ .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions"))))
+ )
+ )
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetRequestCapabilityHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType,
+ ITypeSymbol capability)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ var capabilityName = ResolveTypeName(capability);
+ return ArrowExpressionClause(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "RequestCapability",
+ requestName,
+ responseName,
+ capabilityName
+ )
+ .WithArgumentList(GetHandlerArgumentList()))
+ )
+ );
+ }
+
+ public static BlockSyntax GetRequestRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType,
+ ITypeSymbol registrationOptions)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ var registrationOptionsName = ResolveTypeName(registrationOptions);
+ return Block(
+ EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName),
+ ReturnStatement(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "RequestRegistration",
+ requestName,
+ responseName,
+ registrationOptionsName
+ )
+ .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions"))))
+ )
+ )
+ );
+ }
+
+ public static BlockSyntax GetVoidRequestRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol registrationOptions)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var registrationOptionsName = ResolveTypeName(registrationOptions);
+ return Block(
+ EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName),
+ ReturnStatement(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "RequestRegistration",
+ requestName,
+ registrationOptionsName
+ )
+ .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions"))))
+ )
+ )
+ );
+ }
+
+ public static BlockSyntax GetRequestRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType,
+ ITypeSymbol registrationOptions,
+ ITypeSymbol capability)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ var registrationOptionsName = ResolveTypeName(registrationOptions);
+ var capabilityName = ResolveTypeName(capability);
+ return Block(
+ EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName),
+ ReturnStatement(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "Request",
+ requestName,
+ responseName,
+ capabilityName,
+ registrationOptionsName
+ )
+ .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions"))))
+ )
+ )
+ );
+ }
+
+ public static BlockSyntax GetVoidRequestRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol registrationOptions,
+ ITypeSymbol capability)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var registrationOptionsName = ResolveTypeName(registrationOptions);
+ var capabilityName = ResolveTypeName(capability);
+ return Block(
+ EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName),
+ ReturnStatement(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "Request",
+ requestName,
+ capabilityName,
+ registrationOptionsName
+ )
+ .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions"))))
+ )
+ )
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetRequestHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ return ArrowExpressionClause(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "Request",
+ requestName,
+ responseName
+ )
+ .WithArgumentList(GetHandlerArgumentList()))
+ )
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetRequestHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType)
+ {
+ var requestName = ResolveTypeName(requestType);
+ return ArrowExpressionClause(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "Request",
+ requestName
+ )
+ .WithArgumentList(GetHandlerArgumentList()))
+ )
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetPartialResultCapabilityHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType,
+ ITypeSymbol capability)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ var capabilityName = ResolveTypeName(capability);
+ return ArrowExpressionClause(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "PartialResultCapability",
+ requestName,
+ responseName,
+ capabilityName
+ )
+ .WithArgumentList(GetPartialResultArgumentList(responseName)))
+ )
+ );
+ }
+
+ public static BlockSyntax GetPartialResultRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType,
+ ITypeSymbol registrationOptions)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ var registrationOptionsName = ResolveTypeName(registrationOptions);
+ return Block(
+ EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName),
+ ReturnStatement(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "PartialResult",
+ requestName,
+ responseName,
+ registrationOptionsName
+ )
+ .WithArgumentList(GetPartialResultRegistrationArgumentList(IdentifierName("registrationOptions"), responseName)))
+ )
+ )
+ );
+ }
+
+ public static BlockSyntax GetPartialResultRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType,
+ ITypeSymbol registrationOptions,
+ ITypeSymbol capability)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ var registrationOptionsName = ResolveTypeName(registrationOptions);
+ var capabilityName = ResolveTypeName(capability);
+ return Block(
+ EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName),
+ ReturnStatement(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "PartialResult",
+ requestName,
+ responseName,
+ capabilityName,
+ registrationOptionsName
+ )
+ .WithArgumentList(GetPartialResultRegistrationArgumentList(IdentifierName("registrationOptions"), responseName)))
+ )
+ )
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetPartialResultHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ return ArrowExpressionClause(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "PartialResult",
+ requestName,
+ responseName
+ )
+ .WithArgumentList(GetPartialResultArgumentList(responseName)))
+ )
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetPartialResultsCapabilityHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType,
+ NameSyntax itemName, ITypeSymbol capability)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ var capabilityName = ResolveTypeName(capability);
+ return ArrowExpressionClause(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ SimpleLambdaExpression(
+ Parameter(
+ Identifier("_")),
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "PartialResultsCapability",
+ requestName,
+ responseName,
+ itemName,
+ capabilityName
+ )
+ .WithArgumentList(GetPartialItemsArgumentList(responseName)))
+ )
+ )
+ );
+ }
+
+ public static BlockSyntax GetPartialResultsRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType,
+ NameSyntax itemName, ITypeSymbol registrationOptions)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ var registrationOptionsName = ResolveTypeName(registrationOptions);
+ return Block(
+ EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName),
+ ReturnStatement(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ SimpleLambdaExpression(
+ Parameter(
+ Identifier("_")),
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "PartialResults",
+ requestName,
+ responseName,
+ itemName,
+ registrationOptionsName
+ )
+ .WithArgumentList(GetPartialItemsRegistrationArgumentList(IdentifierName("registrationOptions"), responseName)))
+ )
+ ))
+ );
+ }
+
+ public static BlockSyntax GetPartialResultsRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType,
+ NameSyntax itemName, ITypeSymbol registrationOptions,
+ ITypeSymbol capability)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ var registrationOptionsName = ResolveTypeName(registrationOptions);
+ var capabilityName = ResolveTypeName(capability);
+ return Block(
+ EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName),
+ ReturnStatement(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ SimpleLambdaExpression(
+ Parameter(
+ Identifier("_")),
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "PartialResults",
+ requestName,
+ responseName,
+ itemName,
+ capabilityName,
+ registrationOptionsName
+ )
+ .WithArgumentList(GetPartialItemsRegistrationArgumentList(IdentifierName("registrationOptions"), responseName)))
+ )
+ ))
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetPartialResultsHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, NameSyntax itemName,
+ ITypeSymbol responseType)
+ {
+ var requestName = ResolveTypeName(requestType);
+ var responseName = ResolveTypeName(responseType);
+ return ArrowExpressionClause(
+ AddHandler(
+ Argument(nameExpression),
+ Argument(
+ SimpleLambdaExpression(
+ Parameter(
+ Identifier("_")),
+ CreateHandlerArgument(
+ IdentifierName("LanguageProtocolDelegatingHandlers"),
+ "PartialResults",
+ requestName,
+ responseName,
+ itemName
+ )
+ .WithArgumentList(GetPartialItemsArgumentList(responseName)))
+ )
+ )
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetNotificationHandlerExpression(ExpressionSyntax nameExpression)
+ {
+ return ArrowExpressionClause(
+ InvocationExpression(MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("registry"),
+ IdentifierName("AddHandler")
+ ))
+ .WithArgumentList(ArgumentList(SeparatedList(new SyntaxNodeOrToken[] {
+ Argument(nameExpression),
+ Token(SyntaxKind.CommaToken),
+ Argument(InvocationExpression(MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("NotificationHandler"),
+ IdentifierName("For")
+ ))
+ .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("handler"))))))
+ })))
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetRequestHandlerExpression(ExpressionSyntax nameExpression)
+ {
+ return ArrowExpressionClause(
+ InvocationExpression(MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("registry"),
+ IdentifierName("AddHandler")
+ ))
+ .WithArgumentList(ArgumentList(SeparatedList(new SyntaxNodeOrToken[] {
+ Argument(nameExpression),
+ Token(SyntaxKind.CommaToken),
+ Argument(InvocationExpression(MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("RequestHandler"),
+ IdentifierName("For")
+ ))
+ .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("handler"))))))
+ })))
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetNotificationInvokeExpression()
+ {
+ return ArrowExpressionClause(
+ InvocationExpression(
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("mediator"),
+ IdentifierName("SendNotification")))
+ .WithArgumentList(
+ ArgumentList(
+ SeparatedList(new[] {
+ Argument(IdentifierName(@"@params"))
+ }))
+ )
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetRequestInvokeExpression()
+ {
+ return ArrowExpressionClause(
+ InvocationExpression(
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("mediator"),
+ IdentifierName("SendRequest")))
+ .WithArgumentList(
+ ArgumentList(
+ SeparatedList(new[] {
+ Argument(IdentifierName(@"@params")),
+ Argument(IdentifierName("cancellationToken"))
+ }))
+ )
+ );
+ }
+
+ public static ArrowExpressionClauseSyntax GetPartialInvokeExpression(NameSyntax responseType)
+ {
+ return ArrowExpressionClause(
+ InvocationExpression(
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("mediator"),
+ IdentifierName("ProgressManager")),
+ IdentifierName("MonitorUntil")))
+ .WithArgumentList(
+ ArgumentList(
+ SeparatedList(
+ new [] {
+ Argument(
+ IdentifierName(@"@params")),
+ Argument(
+ SimpleLambdaExpression(
+ Parameter(Identifier("value")),
+ ObjectCreationExpression(responseType)
+ .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("value"))))))),
+ Argument(IdentifierName("cancellationToken"))
+ }))));
+ }
+
+ public static string GetExtensionClassName(INamedTypeSymbol symbol)
+ {
+ return SpecialCasedHandlerFullName(symbol).Split('.').Last() + "Extensions";
+ ;
+ }
+
+ private static string SpecialCasedHandlerFullName(INamedTypeSymbol symbol)
+ {
+ return new Regex(@"(\w+)$")
+ .Replace(symbol.ToDisplayString() ?? string.Empty,
+ symbol.Name.Substring(1, symbol.Name.IndexOf("Handler", StringComparison.Ordinal) - 1))
+ ;
+ }
+
+ public static string SpecialCasedHandlerName(INamedTypeSymbol symbol)
+ {
+ var name = SpecialCasedHandlerFullName(symbol);
+ return name.Substring(name.LastIndexOf('.') + 1);
+ }
+
+ public static string GetOnMethodName(INamedTypeSymbol symbol, AttributeData attributeData)
+ {
+ var namedMethod = attributeData.NamedArguments
+ .Where(z => z.Key == "MethodName")
+ .Select(z => z.Value.Value)
+ .FirstOrDefault();
+ if (namedMethod is string value) return value;
+ return "On" + SpecialCasedHandlerName(symbol);
+ }
+
+ public static string GetSendMethodName(INamedTypeSymbol symbol, AttributeData attributeData)
+ {
+ var namedMethod = attributeData.NamedArguments
+ .Where(z => z.Key == "MethodName")
+ .Select(z => z.Value.Value)
+ .FirstOrDefault();
+ if (namedMethod is string value) return value;
+ var name = SpecialCasedHandlerName(symbol);
+ if (
+ name.StartsWith("Run")
+ // TODO: Change this next breaking change
+ // || name.StartsWith("Set")
+ // || name.StartsWith("Attach")
+ // || name.StartsWith("Read")
+ || name.StartsWith("Did")
+ || name.StartsWith("Log")
+ || name.StartsWith("Show")
+ || name.StartsWith("Register")
+ || name.StartsWith("Prepare")
+ || name.StartsWith("Publish")
+ || name.StartsWith("ApplyWorkspaceEdit")
+ || name.StartsWith("Unregister"))
+ {
+ return name;
+ }
+
+ if (name.EndsWith("Resolve", StringComparison.Ordinal))
+ {
+ return "Resolve" + name.Substring(0, name.IndexOf("Resolve", StringComparison.Ordinal));
+ }
+
+ return IsNotification(symbol) ? "Send" + name : "Request" + name;
+ }
+
+ private static string HandlerName(INamedTypeSymbol symbol)
+ {
+ var name = HandlerFullName(symbol);
+ return name.Substring(name.LastIndexOf('.') + 1);
+ }
+
+ private static string HandlerFullName(INamedTypeSymbol symbol)
+ {
+ return new Regex(@"(\w+)$")
+ .Replace(symbol.ToDisplayString() ?? string.Empty,
+ symbol.Name.Substring(1, symbol.Name.IndexOf("Handler", StringComparison.Ordinal) - 1));
+ }
+ }
+}
diff --git a/src/JsonRpc.Generators/JsonRpc.Generators.csproj b/src/JsonRpc.Generators/JsonRpc.Generators.csproj
new file mode 100644
index 000000000..3a1bf25f0
--- /dev/null
+++ b/src/JsonRpc.Generators/JsonRpc.Generators.csproj
@@ -0,0 +1,17 @@
+
+
+
+
+
+ netstandard2.0
+ OmniSharp.Extensions.JsonRpc.Generators
+ OmniSharp.Extensions.JsonRpc.Generators
+ false
+
+
+
+
+
+
+
+
diff --git a/src/JsonRpc/Generation/GenerateHandlerMethodsAttribute.cs b/src/JsonRpc/Generation/GenerateHandlerMethodsAttribute.cs
new file mode 100644
index 000000000..5c909df09
--- /dev/null
+++ b/src/JsonRpc/Generation/GenerateHandlerMethodsAttribute.cs
@@ -0,0 +1,24 @@
+using System;
+using System.Diagnostics;
+using CodeGeneration.Roslyn;
+
+namespace OmniSharp.Extensions.JsonRpc.Generation
+{
+ ///
+ /// Allows generating OnXyz handler methods for a given IJsonRpcHandler
+ ///
+ ///
+ /// Efforts will be made to make this available for consumers once source generators land
+ ///
+ [AttributeUsage(AttributeTargets.Interface)]
+ [CodeGenerationAttribute("OmniSharp.Extensions.JsonRpc.Generators.GenerateHandlerMethodsGenerator, OmniSharp.Extensions.JsonRpc.Generators")]
+ [Conditional("CodeGeneration")]
+ public class GenerateHandlerMethodsAttribute : Attribute
+ {
+ public GenerateHandlerMethodsAttribute(params Type[] registryTypes)
+ {
+ }
+
+ public string MethodName { get; set; }
+ }
+}
diff --git a/src/JsonRpc/Generation/GenerateRequestMethodsAttribute.cs b/src/JsonRpc/Generation/GenerateRequestMethodsAttribute.cs
new file mode 100644
index 000000000..2a3905ef1
--- /dev/null
+++ b/src/JsonRpc/Generation/GenerateRequestMethodsAttribute.cs
@@ -0,0 +1,24 @@
+using System;
+using System.Diagnostics;
+using CodeGeneration.Roslyn;
+
+namespace OmniSharp.Extensions.JsonRpc.Generation
+{
+ ///
+ /// Allows generating SendXyz/RequestXyz methods for a given IJsonRpcHandler
+ ///
+ ///
+ /// Efforts will be made to make this available for consumers once source generators land
+ ///
+ [AttributeUsage(AttributeTargets.Interface)]
+ [CodeGenerationAttribute("OmniSharp.Extensions.JsonRpc.Generators.GenerateRequestMethodsGenerator, OmniSharp.Extensions.JsonRpc.Generators")]
+ [Conditional("CodeGeneration")]
+ public class GenerateRequestMethodsAttribute : Attribute
+ {
+ public GenerateRequestMethodsAttribute(params Type[] proxyTypes)
+ {
+ }
+
+ public string MethodName { get; set; }
+ }
+}
diff --git a/src/JsonRpc/JsonRpc.csproj b/src/JsonRpc/JsonRpc.csproj
index 6c8e67ac4..8cb348ece 100644
--- a/src/JsonRpc/JsonRpc.csproj
+++ b/src/JsonRpc/JsonRpc.csproj
@@ -18,8 +18,12 @@
+
+
+ <_Parameter1>OmniSharp.Extensions.DebugAdapter, PublicKey=0024000004800000940000000602000000240000525341310004000001000100391db875e68eb4bfef49ce14313b9e13f2cd3cc89eb273bbe6c11a55044c7d4f566cf092e1c77ef9e7c75b1496ae7f95d925938f5a01793dd8d9f99ae0a7595779b71b971287d7d7b5960d052078d14f5ce1a85ea5c9fb2f59ac735ff7bc215cab469b7c3486006860bad6f4c3b5204ea2f28dd4e1d05e2cca462cfd593b9f9f
+
<_Parameter1>OmniSharp.Extensions.LanguageProtocol, PublicKey=0024000004800000940000000602000000240000525341310004000001000100391db875e68eb4bfef49ce14313b9e13f2cd3cc89eb273bbe6c11a55044c7d4f566cf092e1c77ef9e7c75b1496ae7f95d925938f5a01793dd8d9f99ae0a7595779b71b971287d7d7b5960d052078d14f5ce1a85ea5c9fb2f59ac735ff7bc215cab469b7c3486006860bad6f4c3b5204ea2f28dd4e1d05e2cca462cfd593b9f9f
diff --git a/src/Protocol/Document/ITypeDefinitionHandler.cs b/src/Protocol/Document/ITypeDefinitionHandler.cs
index 10d7d3021..9a807da24 100644
--- a/src/Protocol/Document/ITypeDefinitionHandler.cs
+++ b/src/Protocol/Document/ITypeDefinitionHandler.cs
@@ -4,6 +4,7 @@
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using OmniSharp.Extensions.JsonRpc;
+using OmniSharp.Extensions.JsonRpc.Generation;
using OmniSharp.Extensions.LanguageServer.Protocol.Client;
using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities;
using OmniSharp.Extensions.LanguageServer.Protocol.Models;
@@ -13,6 +14,7 @@
namespace OmniSharp.Extensions.LanguageServer.Protocol.Document
{
[Parallel, Method(TextDocumentNames.TypeDefinition, Direction.ClientToServer)]
+ [GenerateHandlerMethods, GenerateRequestMethods(typeof(ITextDocumentLanguageClient), typeof(ILanguageClient))]
public interface ITypeDefinitionHandler : IJsonRpcRequestHandler, IRegistration, ICapability { }
public abstract class TypeDefinitionHandler : ITypeDefinitionHandler
@@ -28,96 +30,4 @@ public TypeDefinitionHandler(TypeDefinitionRegistrationOptions registrationOptio
public virtual void SetCapability(TypeDefinitionCapability capability) => Capability = capability;
protected TypeDefinitionCapability Capability { get; private set; }
}
-
- public static class TypeDefinitionExtensions
- {
-public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry,
- Func>
- handler,
- TypeDefinitionRegistrationOptions registrationOptions)
- {
- registrationOptions ??= new TypeDefinitionRegistrationOptions();
- return registry.AddHandler(TextDocumentNames.TypeDefinition,
- new LanguageProtocolDelegatingHandlers.Request(handler, registrationOptions));
- }
-
-public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry,
- Func> handler,
- TypeDefinitionRegistrationOptions registrationOptions)
- {
- registrationOptions ??= new TypeDefinitionRegistrationOptions();
- return registry.AddHandler(TextDocumentNames.TypeDefinition,
- new LanguageProtocolDelegatingHandlers.RequestRegistration(handler, registrationOptions));
- }
-
-public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry,
- Func> handler,
- TypeDefinitionRegistrationOptions registrationOptions)
- {
- registrationOptions ??= new TypeDefinitionRegistrationOptions();
- return registry.AddHandler(TextDocumentNames.TypeDefinition,
- new LanguageProtocolDelegatingHandlers.RequestRegistration(handler, registrationOptions));
- }
-
-public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry,
- Action>, TypeDefinitionCapability,
- CancellationToken> handler,
- TypeDefinitionRegistrationOptions registrationOptions)
- {
- registrationOptions ??= new TypeDefinitionRegistrationOptions();
- return registry.AddHandler(TextDocumentNames.TypeDefinition,
- _ =>
- new LanguageProtocolDelegatingHandlers.PartialResults(handler,
- registrationOptions, _.GetService(), x => new LocationOrLocationLinks(x)));
- }
-
-public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry,
- Action>, TypeDefinitionCapability>
- handler,
- TypeDefinitionRegistrationOptions registrationOptions)
- {
- registrationOptions ??= new TypeDefinitionRegistrationOptions();
- return registry.AddHandler(TextDocumentNames.TypeDefinition,
- _ =>
- new LanguageProtocolDelegatingHandlers.PartialResults(handler,
- registrationOptions, _.GetService(), x => new LocationOrLocationLinks(x)));
- }
-
-public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry,
- Action>, CancellationToken> handler,
- TypeDefinitionRegistrationOptions registrationOptions)
- {
- registrationOptions ??= new TypeDefinitionRegistrationOptions();
- return registry.AddHandler(TextDocumentNames.TypeDefinition,
- _ =>
- new LanguageProtocolDelegatingHandlers.PartialResults(handler, registrationOptions,
- _.GetService(), x => new LocationOrLocationLinks(x)));
- }
-
-public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry,
- Action>> handler,
- TypeDefinitionRegistrationOptions registrationOptions)
- {
- registrationOptions ??= new TypeDefinitionRegistrationOptions();
- return registry.AddHandler(TextDocumentNames.TypeDefinition,
- _ =>
- new LanguageProtocolDelegatingHandlers.PartialResults(handler, registrationOptions,
- _.GetService(), x => new LocationOrLocationLinks(x)));
- }
-
- public static IRequestProgressObservable, LocationOrLocationLinks> RequestTypeDefinition(
- this ITextDocumentLanguageClient mediator,
- TypeDefinitionParams @params,
- CancellationToken cancellationToken = default)
- {
- return mediator.ProgressManager.MonitorUntil(@params, x => new LocationOrLocationLinks(x), cancellationToken);
- }
- }
}
diff --git a/src/Protocol/General/ILanguageProtocolInitializeHandler.cs b/src/Protocol/General/ILanguageProtocolInitializeHandler.cs
index 14f081489..b6d7a6d8c 100644
--- a/src/Protocol/General/ILanguageProtocolInitializeHandler.cs
+++ b/src/Protocol/General/ILanguageProtocolInitializeHandler.cs
@@ -2,6 +2,7 @@
using System.Threading;
using System.Threading.Tasks;
using OmniSharp.Extensions.JsonRpc;
+using OmniSharp.Extensions.JsonRpc.Generation;
using OmniSharp.Extensions.LanguageServer.Protocol.Client;
using OmniSharp.Extensions.LanguageServer.Protocol.Models;
using OmniSharp.Extensions.LanguageServer.Protocol.Server;
@@ -12,6 +13,7 @@ namespace OmniSharp.Extensions.LanguageServer.Protocol.General
/// InitializeError
///
[Serial, Method(GeneralNames.Initialize, Direction.ClientToServer)]
+ [GenerateHandlerMethods(typeof(ILanguageServerRegistry), MethodName = "OnLanguageProtocolInitialize"), GenerateRequestMethods(typeof(ILanguageClient), MethodName = "RequestLanguageProtocolInitialize")]
public interface ILanguageProtocolInitializeHandler : IJsonRpcRequestHandler
{
}
@@ -20,26 +22,4 @@ public abstract class LanguageProtocolInitializeHandler : ILanguageProtocolIniti
{
public abstract Task Handle(InitializeParams request, CancellationToken cancellationToken);
}
-
- public static class LanguageProtocolInitializeExtensions
- {
- public static ILanguageServerRegistry OnLanguageProtocolInitialize(this ILanguageServerRegistry registry,
- Func>
- handler)
- {
- return registry.AddHandler(GeneralNames.Initialize, RequestHandler.For(handler));
- }
-
- public static ILanguageServerRegistry OnLanguageProtocolInitialize(this ILanguageServerRegistry registry,
- Func> handler)
- {
- return registry.AddHandler(GeneralNames.Initialize, RequestHandler.For(handler));
- }
-
- public static Task RequestLanguageProtocolInitialize(this ILanguageClient mediator, InitializeParams @params,
- CancellationToken cancellationToken = default)
- {
- return mediator.SendRequest(@params, cancellationToken);
- }
- }
}
diff --git a/src/Protocol/General/IShutdownHandler.cs b/src/Protocol/General/IShutdownHandler.cs
index b3837006d..ca0696ddb 100644
--- a/src/Protocol/General/IShutdownHandler.cs
+++ b/src/Protocol/General/IShutdownHandler.cs
@@ -3,6 +3,7 @@
using System.Threading.Tasks;
using MediatR;
using OmniSharp.Extensions.JsonRpc;
+using OmniSharp.Extensions.JsonRpc.Generation;
using OmniSharp.Extensions.LanguageServer.Protocol.Client;
using OmniSharp.Extensions.LanguageServer.Protocol.Models;
using OmniSharp.Extensions.LanguageServer.Protocol.Server;
@@ -10,6 +11,7 @@
namespace OmniSharp.Extensions.LanguageServer.Protocol.General
{
[Serial, Method(GeneralNames.Shutdown, Direction.ClientToServer)]
+ [GenerateHandlerMethods, GenerateRequestMethods(typeof(ITextDocumentLanguageClient), typeof(ILanguageClient))]
public interface IShutdownHandler : IJsonRpcRequestHandler
{
}
@@ -25,33 +27,8 @@ public virtual async Task Handle(ShutdownParams request, CancellationToken
protected abstract Task Handle(CancellationToken cancellationToken);
}
- public static class ShutdownExtensions
+ public static partial class ShutdownExtensions
{
- public static ILanguageServerRegistry OnShutdown(this ILanguageServerRegistry registry,
- Func
- handler)
- {
- return registry.AddHandler(GeneralNames.Shutdown,
- RequestHandler.For(async (_, ct) => {
- await handler(_, ct);
- return Unit.Value;
- }));
- }
-
- public static ILanguageServerRegistry OnShutdown(this ILanguageServerRegistry registry, Func handler)
- {
- return registry.AddHandler(GeneralNames.Shutdown,
- RequestHandler.For(async (_, ct) => {
- await handler(_);
- return Unit.Value;
- }));
- }
-
- public static Task RequestShutdown(this ILanguageClient mediator, ShutdownParams @params, CancellationToken cancellationToken = default)
- {
- return mediator.SendRequest(@params, cancellationToken);
- }
-
public static Task RequestShutdown(this ILanguageClient mediator, CancellationToken cancellationToken = default)
{
return mediator.SendRequest(ShutdownParams.Instance, cancellationToken);
diff --git a/src/Protocol/IProgressHandler.cs b/src/Protocol/IProgressHandler.cs
index 469a820cc..c61a0e034 100644
--- a/src/Protocol/IProgressHandler.cs
+++ b/src/Protocol/IProgressHandler.cs
@@ -4,6 +4,7 @@
using System.Threading.Tasks;
using MediatR;
using OmniSharp.Extensions.JsonRpc;
+using OmniSharp.Extensions.JsonRpc.Generation;
using OmniSharp.Extensions.LanguageServer.Protocol.Client;
using OmniSharp.Extensions.LanguageServer.Protocol.Models;
using OmniSharp.Extensions.LanguageServer.Protocol.Progress;
@@ -12,6 +13,7 @@
namespace OmniSharp.Extensions.LanguageServer.Protocol
{
[Parallel, Method(GeneralNames.Progress, Direction.Bidirectional)]
+ [GenerateHandlerMethods, GenerateRequestMethods(typeof(IGeneralLanguageClient), typeof(ILanguageClient), typeof(IGeneralLanguageServer), typeof(ILanguageServer))]
public interface IProgressHandler : IJsonRpcNotificationHandler
{
}
@@ -21,66 +23,8 @@ public abstract class ProgressHandler : IProgressHandler
public abstract Task Handle(ProgressParams request, CancellationToken cancellationToken);
}
- public static class ProgressExtensions
+ public static partial class ProgressExtensions
{
-public static ILanguageServerRegistry OnProgress(this ILanguageServerRegistry registry,
- Action handler)
- {
- return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler));
- }
-
-public static ILanguageServerRegistry OnProgress(this ILanguageServerRegistry registry,
- Func handler)
- {
- return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler));
- }
-
-public static ILanguageServerRegistry OnProgress(this ILanguageServerRegistry registry,
- Action handler)
- {
- return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler));
- }
-
-public static ILanguageServerRegistry OnProgress(this ILanguageServerRegistry registry,
- Func handler)
- {
- return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler));
- }
-
-public static ILanguageClientRegistry OnProgress(this ILanguageClientRegistry registry,
- Action handler)
- {
- return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler));
- }
-
-public static ILanguageClientRegistry OnProgress(this ILanguageClientRegistry registry,
- Func handler)
- {
- return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler));
- }
-
-public static ILanguageClientRegistry OnProgress(this ILanguageClientRegistry registry,
- Action handler)
- {
- return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler));
- }
-
-public static ILanguageClientRegistry OnProgress(this ILanguageClientRegistry registry,
- Func handler)
- {
- return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler));
- }
-
- public static void SendProgress(this IGeneralLanguageClient registry, ProgressParams @params)
- {
- registry.SendNotification(@params);
- }
-
- public static void SendProgress(this IGeneralLanguageServer registry, ProgressParams @params)
- {
- registry.SendNotification(@params);
- }
-
public static IRequestProgressObservable RequestProgress(this IClientProxy requestRouter, IPartialItemRequest @params, Func factory, CancellationToken cancellationToken = default)
{
var resultToken = new ProgressToken(Guid.NewGuid().ToString());
diff --git a/src/Protocol/Protocol.csproj b/src/Protocol/Protocol.csproj
index 996e1cb95..1661083a9 100644
--- a/src/Protocol/Protocol.csproj
+++ b/src/Protocol/Protocol.csproj
@@ -8,6 +8,8 @@
+
+
<_Parameter1>OmniSharp.Extensions.LanguageServer, PublicKey=0024000004800000940000000602000000240000525341310004000001000100391db875e68eb4bfef49ce14313b9e13f2cd3cc89eb273bbe6c11a55044c7d4f566cf092e1c77ef9e7c75b1496ae7f95d925938f5a01793dd8d9f99ae0a7595779b71b971287d7d7b5960d052078d14f5ce1a85ea5c9fb2f59ac735ff7bc215cab469b7c3486006860bad6f4c3b5204ea2f28dd4e1d05e2cca462cfd593b9f9f
@@ -18,4 +20,7 @@
<_Parameter1>OmniSharp.Extensions.LanguageClient, PublicKey=0024000004800000940000000602000000240000525341310004000001000100391db875e68eb4bfef49ce14313b9e13f2cd3cc89eb273bbe6c11a55044c7d4f566cf092e1c77ef9e7c75b1496ae7f95d925938f5a01793dd8d9f99ae0a7595779b71b971287d7d7b5960d052078d14f5ce1a85ea5c9fb2f59ac735ff7bc215cab469b7c3486006860bad6f4c3b5204ea2f28dd4e1d05e2cca462cfd593b9f9f
+
+
+
diff --git a/test/Dap.Tests/Dap.Tests.csproj b/test/Dap.Tests/Dap.Tests.csproj
index 2498bb013..ca60129bc 100644
--- a/test/Dap.Tests/Dap.Tests.csproj
+++ b/test/Dap.Tests/Dap.Tests.csproj
@@ -9,6 +9,7 @@
+
diff --git a/test/Dap.Tests/FoundationTests.cs b/test/Dap.Tests/FoundationTests.cs
index 7ea1bc34b..fc5a1df19 100644
--- a/test/Dap.Tests/FoundationTests.cs
+++ b/test/Dap.Tests/FoundationTests.cs
@@ -250,7 +250,9 @@ public void Match(string description, params Func[] matchers)
new[] {
registrySub, Substitute.For(new Type[] {method.GetParameters()[1].ParameterType}, Array.Empty