diff --git a/case_utils/case_sparql_select/__init__.py b/case_utils/case_sparql_select/__init__.py
index eaa98cb..8b25f30 100644
--- a/case_utils/case_sparql_select/__init__.py
+++ b/case_utils/case_sparql_select/__init__.py
@@ -49,74 +49,44 @@ _logger = logging.getLogger(os.path.basename(__file__))
-def main() -> None:
-    parser = argparse.ArgumentParser()
-
-    # Configure debug logging before running parse_args, because there could be an error raised before the construction of the argument parser.
-    logging.basicConfig(
-        level=logging.DEBUG
-        if ("--debug" in sys.argv or "-d" in sys.argv)
-        else logging.INFO
-    )
-
-    parser.add_argument("-d", "--debug", action="store_true")
-    parser.add_argument(
-        "--built-version",
-        choices=tuple(built_version_choices_list),
-        default="case-" + CURRENT_CASE_VERSION,
-        help="Ontology version to use to supplement query, such as for subclass querying. Does not require networking to use. Default is most recent CASE release. Passing 'none' will mean no pre-built CASE ontology versions accompanying this tool will be included in the analysis.",
-    )
-    parser.add_argument(
-        "--disallow-empty-results",
-        action="store_true",
-        help="Raise error if no results are returned for query.",
-    )
-    parser.add_argument(
-        "--use-prefixes",
-        action="store_true",
-        help="Abbreviate node IDs according to graph's encoded prefixes. (This will use prefixes in the graph, not the query.)",
-    )
-    parser.add_argument(
-        "out_table",
-        help="Expected extensions are .html for HTML tables, .md for Markdown tables, .csv for comma-separated values, and .tsv for tab-separated values.",
-    )
-    parser.add_argument(
-        "in_sparql",
-        help="File containing a SPARQL SELECT query. Note that prefixes not mapped with a PREFIX statement will be mapped according to their first occurrence among input graphs.",
-    )
-    parser.add_argument("in_graph", nargs="+")
-    args = parser.parse_args()
-
-    graph = rdflib.Graph()
-    for in_graph_filename in args.in_graph:
-        graph.parse(in_graph_filename)
-
-    # Inherit prefixes defined in input context dictionary.
-    nsdict = {k: v for (k, v) in graph.namespace_manager.namespaces()}
-
-    select_query_text = None
-    with open(args.in_sparql, "r") as in_fh:
-        select_query_text = in_fh.read().strip()
-    _logger.debug("select_query_text = %r." % select_query_text)
-
-    if "subClassOf" in select_query_text:
-        case_utils.ontology.load_subclass_hierarchy(
-            graph, built_version=args.built_version
-        )
-
+def query_text_to_variables(select_query_text: str) -> typing.List[str]:
     # Build columns list from SELECT line.
     select_query_text_lines = select_query_text.split("\n")
     select_line = [
         line for line in select_query_text_lines if line.startswith("SELECT ")
     ][0]
     variables = select_line.replace(" DISTINCT", "").replace("SELECT ", "").split(" ")
+    return variables
+
+
+def graph_and_query_to_data_frame(
+    graph: rdflib.Graph,
+    select_query_text: str,
+    *args: typing.Any,
+    built_version: str = "case-" + CURRENT_CASE_VERSION,
+    disallow_empty_results: bool = False,
+    use_prefixes: bool = False,
+    **kwargs: typing.Any,
+) -> pd.DataFrame:
+    # Inherit prefixes defined in input context dictionary.
+    nsdict = {k: v for (k, v) in graph.namespace_manager.namespaces()}
+
+    # Avoid side-effects on input parameter.
+ if "subClassOf" in select_query_text: + _graph = rdflib.Graph() + _graph += graph + case_utils.ontology.load_subclass_hierarchy(_graph, built_version=built_version) + else: + _graph = graph + + variables = query_text_to_variables(select_query_text) tally = 0 records = [] select_query_object = rdflib.plugins.sparql.processor.prepareQuery( select_query_text, initNs=nsdict ) - for (row_no, row) in enumerate(graph.query(select_query_object)): + for (row_no, row) in enumerate(_graph.query(select_query_object)): tally = row_no + 1 record = [] for (column_no, column) in enumerate(row): @@ -131,7 +101,7 @@ def main() -> None: # .decode() is because hexlify returns bytes. column_value = binascii.hexlify(column.toPython()).decode() elif isinstance(column, rdflib.URIRef): - if args.use_prefixes: + if use_prefixes: column_value = graph.namespace_manager.qname(column.toPython()) else: column_value = column.toPython() @@ -141,39 +111,128 @@ def main() -> None: _logger.debug("row[0]column[%d] = %r." % (column_no, column_value)) record.append(column_value) records.append(record) + if tally == 0: - if args.disallow_empty_results: + if disallow_empty_results: raise ValueError("Failed to return any results.") df = pd.DataFrame(records, columns=variables) + return df + +def data_frame_to_table_text( + df: pd.DataFrame, + *args: typing.Any, + output_mode: str, + **kwargs: typing.Any, +) -> str: table_text: typing.Optional[str] = None - if args.out_table.endswith(".csv") or args.out_table.endswith(".tsv"): - # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_csv.html + + if output_mode in {"csv", "tsv"}: sep: str - if args.out_table.endswith(".csv"): + if output_mode == "csv": sep = "," - elif args.out_table.endswith(".tsv"): + elif output_mode == "tsv": sep = "\t" else: raise NotImplementedError( "Output extension not implemented in CSV-style output." ) table_text = df.to_csv(sep=sep) - elif args.out_table.endswith(".html"): + elif output_mode == "html": # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_html.html # Add CSS classes for CASE website Bootstrap support. table_text = df.to_html(classes=("table", "table-bordered", "table-condensed")) - elif args.out_table.endswith(".md"): + elif output_mode == "md": # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_markdown.html # https://pypi.org/project/tabulate/ # Assume Github-flavored Markdown. + table_text = df.to_markdown(tablefmt="github") - if table_text is None: - raise NotImplementedError( - "Unsupported output extension for output filename %r.", args.out_table - ) + else: + if table_text is None: + raise NotImplementedError("Unimplemented output mode: %r." % output_mode) + assert table_text is not None + + return table_text + + +def main() -> None: + parser = argparse.ArgumentParser() + + # Configure debug logging before running parse_args, because there could be an error raised before the construction of the argument parser. + logging.basicConfig( + level=logging.DEBUG + if ("--debug" in sys.argv or "-d" in sys.argv) + else logging.INFO + ) + + parser.add_argument("-d", "--debug", action="store_true") + parser.add_argument( + "--built-version", + choices=tuple(built_version_choices_list), + default="case-" + CURRENT_CASE_VERSION, + help="Ontology version to use to supplement query, such as for subclass querying. Does not require networking to use. Default is most recent CASE release. 
Passing 'none' will mean no pre-built CASE ontology versions accompanying this tool will be included in the analysis.", + ) + parser.add_argument( + "--disallow-empty-results", + action="store_true", + help="Raise error if no results are returned for query.", + ) + parser.add_argument( + "--use-prefixes", + action="store_true", + help="Abbreviate node IDs according to graph's encoded prefixes. (This will use prefixes in the graph, not the query.)", + ) + parser.add_argument( + "out_table", + help="Expected extensions are .html for HTML tables, .md for Markdown tables, .csv for comma-separated values, and .tsv for tab-separated values.", + ) + parser.add_argument( + "in_sparql", + help="File containing a SPARQL SELECT query. Note that prefixes not mapped with a PREFIX statement will be mapped according to their first occurrence among input graphs.", + ) + + parser.add_argument("in_graph", nargs="+") + args = parser.parse_args() + output_mode: str + if args.out_table.endswith(".csv"): + output_mode = "csv" + elif args.out_table.endswith(".html"): + output_mode = "html" + elif args.out_table.endswith(".json"): + output_mode = "json" + elif args.out_table.endswith(".md"): + output_mode = "md" + elif args.out_table.endswith(".tsv"): + output_mode = "tsv" + else: + raise NotImplementedError("Output file extension not implemented.") + + graph = rdflib.Graph() + for in_graph_filename in args.in_graph: + graph.parse(in_graph_filename) + + select_query_text: typing.Optional[str] = None + with open(args.in_sparql, "r") as in_fh: + select_query_text = in_fh.read().strip() + if select_query_text is None: + raise ValueError("Failed to load query.") + _logger.debug("select_query_text = %r." % select_query_text) + + df = graph_and_query_to_data_frame( + graph, + select_query_text, + built_version=args.built_version, + disallow_empty_results=args.disallow_empty_results is True, + use_prefixes=args.use_prefixes is True, + ) + + table_text = data_frame_to_table_text( + df, + output_mode=output_mode, + ) with open(args.out_table, "w") as out_fh: out_fh.write(table_text) if table_text[-1] != "\n":
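
Reviewer note, not part of the patch: with main() decomposed into query_text_to_variables, graph_and_query_to_data_frame, and data_frame_to_table_text, the query and rendering steps become importable without going through the CLI. Below is a minimal usage sketch under that assumption; the input file names graph.ttl and query.sparql are hypothetical placeholders, while the function names and keyword signatures are taken from the patch above.

    import rdflib

    import case_utils.case_sparql_select

    # Load an input graph, as main() does with its in_graph arguments.
    graph = rdflib.Graph()
    graph.parse("graph.ttl")  # Hypothetical input graph file.

    # Read a SPARQL SELECT query, as main() does with its in_sparql argument.
    with open("query.sparql", "r") as in_fh:  # Hypothetical query file.
        select_query_text = in_fh.read().strip()

    # Returns a pandas DataFrame whose columns are the query's SELECT variables.
    df = case_utils.case_sparql_select.graph_and_query_to_data_frame(
        graph,
        select_query_text,
        use_prefixes=True,
    )

    # output_mode mirrors the out_table extensions recognized in main():
    # "csv", "html", "md", or "tsv".
    print(case_utils.case_sparql_select.data_frame_to_table_text(df, output_mode="md"))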