Skip to content

refactor: some gen_ops with added code generator. #1063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions TensorFlow.NET.sln
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "helpers", "helpers", "{E1A5
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.UnitTest.RedistHolder", "helpers\Tensorflow.UnitTest.RedistHolder\Tensorflow.UnitTest.RedistHolder.csproj", "{62D543A2-8846-45A3-829B-5754B094A8E2}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.CodeGen", "Tensorflow.CodeGen\Tensorflow.CodeGen.csproj", "{BADBB104-2F03-4824-A249-803A871D8122}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -282,6 +284,24 @@ Global
{62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x64.Build.0 = Release|Any CPU
{62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x86.ActiveCfg = Release|Any CPU
{62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x86.Build.0 = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|Any CPU.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|x64.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|x64.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|x86.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Debug|x86.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|Any CPU.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|x64.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|x64.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|x86.ActiveCfg = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.GPU|x86.Build.0 = Debug|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|Any CPU.ActiveCfg = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|Any CPU.Build.0 = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|x64.ActiveCfg = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|x64.Build.0 = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|x86.ActiveCfg = Release|Any CPU
{BADBB104-2F03-4824-A249-803A871D8122}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand All @@ -300,6 +320,7 @@ Global
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18} = {01A1787F-A9BE-4221-84E8-6360DD010AB6}
{7DEA8760-E401-4872-81F3-405F185A13A0} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD}
{62D543A2-8846-45A3-829B-5754B094A8E2} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
{BADBB104-2F03-4824-A249-803A871D8122} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {2DEAD3CC-486B-4918-A607-50B0DE7B114A}
Expand Down
263 changes: 263 additions & 0 deletions Tensorflow.CodeGen/DescriptionGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
using Microsoft.CodeAnalysis.CSharp;
using Protobuf.Text;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;

namespace Tensorflow.CodeGen
{
public class DescriptionGenerator
{
private static readonly string replaceStrInner = "~~%~~";
private static readonly string replaceStrInnerQuotationMarks = "^%^";
Dictionary<string, Dictionary<string, string>> _opDescriptions = new Dictionary<string, Dictionary<string, string>>();
Dictionary<string, OpDef> _opDescriptionDefs = new Dictionary<string, OpDef>();
public DescriptionGenerator(string apiDefDirectory)
{
DirectoryInfo directory = new DirectoryInfo(apiDefDirectory);

int errors = 0;
foreach (FileInfo file in directory.GetFiles())
{
string target = file.Name.Split('.')[0].Split('_').Last();
OpDef op = null;
try
{
op = ReadOpDefs(file.FullName).Op[0];
}
catch
{
errors++;
continue;
}
_opDescriptionDefs[target] = op;
_opDescriptions[target] = new Dictionary<string, string>();
foreach (var arg in op.InputArg)
{
string argName = arg.Name;
var token = SyntaxFactory.ParseToken(argName);
if (token.IsKeyword())
{
argName = $"{argName}_";
}
_opDescriptions[target][argName] = arg.Description ?? "";
}
foreach (var arg in op.Attr)
{
var token = SyntaxFactory.ParseToken(arg.Name);
string realKey = arg.Name;
if (token.IsKeyword())
{
realKey += "_";
}
_opDescriptions[target][realKey] = arg.Description ?? "";
}
_opDescriptions[target]["SUMMARY"] = op.Summary ?? "";
_opDescriptions[target]["DESC"] = op.Description ?? "";
}
Console.WriteLine($"Warning: {errors} description files cannot be analyzed! Please revise it if " +
$"the failed files number is large, or ignore it.");
}

/// <summary>
///
/// </summary>
/// <param name="op"></param>
/// <param name="sb"></param>
public void AppendDescription(OpDef fullOp, StringBuilder sb)
{
var opName = fullOp.Name;
if(_opDescriptions.TryGetValue(opName, out var op))
{
var def = _opDescriptionDefs[opName];
sb.AppendLine("/// <summary>");
sb.AppendLine($"/// {op["SUMMARY"]}");
sb.AppendLine("/// </summary>");

string totalDesc = op["DESC"];
if (!string.IsNullOrEmpty(totalDesc))
{
totalDesc = totalDesc.Replace(replaceStrInnerQuotationMarks, "\"");
sb.AppendLine("/// <remarks>");
string[] lines = totalDesc.Split(replaceStrInner);
foreach (var line in lines)
{
sb.AppendLine($"/// {line}");
}
sb.AppendLine("/// </remarks>");
}

var argNames = GetInputArgNames(fullOp);
foreach (var argName in argNames)
{
if(op.TryGetValue(argName, out var desc))
{
desc = desc.Replace(replaceStrInnerQuotationMarks, "\"");
string[] lines = desc.Split(replaceStrInner);
sb.AppendLine($"/// <param name=\"{argName}\">");
foreach (var line in lines)
{
sb.AppendLine($"/// {line}");
}
sb.AppendLine("/// </param>");
}
else
{
sb.AppendLine($"/// <param name=\"{argName}\"></param>");
}
}

List<string> returnValueDescs = new();
foreach (var arg in def.OutputArg)
{
if (!string.IsNullOrEmpty(arg.Description))
{
returnValueDescs.Add($"{arg.Name}: {arg.Description}");
}
}
string returnValueDesc = "";
if (returnValueDescs.Count > 0)
{
returnValueDesc = string.Join(" && ", returnValueDescs);
}
sb.AppendLine($"/// <returns>{returnValueDesc}</returns>");
}
else
{
sb.AppendLine("/// <summary>");
sb.AppendLine($"///");
sb.AppendLine("/// </summary>");

var argNames = GetInputArgNames(fullOp);
foreach (var argName in argNames)
{
sb.AppendLine($"/// <param name=\"{argName}\"></param>");
}

sb.AppendLine($"/// <returns></returns>");
}
}

/// <summary>
///
/// </summary>
/// <param name="op">
/// </param>
/// <returns></returns>
/// <remarks></remarks>
public List<string> GetInputArgNames(OpDef op)
{
List<string> names = new();
foreach (var arg in op.InputArg)
{
string argName = arg.Name;
var token = SyntaxFactory.ParseToken(argName);
if (token.IsKeyword())
{
argName = $"{argName}_";
}
names.Add(argName);
}
var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues);
foreach (var (key, typeStr, value) in attrValueDic)
{
var token = SyntaxFactory.ParseToken(key);
string realKey = key;
if (token.IsKeyword())
{
realKey += "_";
}
names.Add(realKey);
}
return names;
}

private static OpList ReadOpDefs(string path)
{
var text = File.ReadAllText(path);
text = RemoveLintTags(text);
text = PreProcessText(text);

string pattern = @"<<END([\s\S]*?)END";

// 定义用于替换的字符串
string replaceStrPrefix = "\"";
string replaceStrSuffix = "\"";

// 将匹配到的文本段全部替换
string replacedText = Regex.Replace(text, pattern, match => {
string matchedText = match.Value;
string innerText = match.Groups[1].Value;
innerText = innerText.Replace("\"", replaceStrInnerQuotationMarks)
.Replace("\r\n", replaceStrInner).Replace("\n", replaceStrInner); // 替换内部换行符
return replaceStrPrefix + innerText + replaceStrSuffix; // 替换首尾
}, RegexOptions.Multiline);

var opDefs = new TextParser(TextParser.Settings.Default.WithIgnoreUnknownFields(true)).Parse<OpList>(replacedText);
return opDefs;
}

static string PreProcessText(string input)
{
int depth = 0;
int endBlockDepth = -1;
StringBuilder sb = new StringBuilder();
for (int i = 0; i < input.Length; i++)
{
char c = input[i];
if (c == '{')
{
depth++;
sb.Append(c);
}
else if (c == '}')
{
if (depth == endBlockDepth)
{
sb.Append("END\n");
endBlockDepth = -1;
}
sb.Append(c);
depth--;
}
else if (c == '<' && i + 5 < input.Length && input.Substring(i, 5) == "<<END")
{
endBlockDepth = depth;
sb.Append("<<END");
i += 4;
}
else if (c == 'E' && i + 3 < input.Length && input.Substring(i, 3) == "END")
{
endBlockDepth = -1;
sb.Append("END");
i += 2;
}
else
{
sb.Append(c);
}
}

string output = sb.ToString();
return output;
}

static string RemoveLintTags(string input)
{
string[] lines = input.Split(new[] { "\r\n", "\r", "\n" }, StringSplitOptions.None);
StringBuilder sb = new StringBuilder();
foreach (string line in lines)
{
if (!line.TrimStart().StartsWith("# LINT"))
{
sb.AppendLine(line);
}
}
return sb.ToString().TrimEnd();
}
}
}
Loading