Skip to content

Commit 98d4c6b

Browse files
committed
Refactor case_sparql_select code
This patch is code-motion to give function names to chunks of `case_sparql_select:main`. Some upcoming patch series are going to add features that, when taken together, introduce non-trivial parameter-value cross-dependencies. Moving functionality to functions enables combinatoric testing in a `pytest` space, rather than resorting to copying, pasting, and tweaking many Makefile lines. A future patch series will add the `pytest` script. Signed-off-by: Alex Nelson <alexander.nelson@nist.gov>
1 parent 83ec9ac commit 98d4c6b

File tree

1 file changed

+127
-68
lines changed

1 file changed

+127
-68
lines changed

case_utils/case_sparql_select/__init__.py

Lines changed: 127 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -49,74 +49,44 @@
4949
_logger = logging.getLogger(os.path.basename(__file__))
5050

5151

52-
def main() -> None:
53-
parser = argparse.ArgumentParser()
54-
55-
# Configure debug logging before running parse_args, because there could be an error raised before the construction of the argument parser.
56-
logging.basicConfig(
57-
level=logging.DEBUG
58-
if ("--debug" in sys.argv or "-d" in sys.argv)
59-
else logging.INFO
60-
)
61-
62-
parser.add_argument("-d", "--debug", action="store_true")
63-
parser.add_argument(
64-
"--built-version",
65-
choices=tuple(built_version_choices_list),
66-
default="case-" + CURRENT_CASE_VERSION,
67-
help="Ontology version to use to supplement query, such as for subclass querying. Does not require networking to use. Default is most recent CASE release. Passing 'none' will mean no pre-built CASE ontology versions accompanying this tool will be included in the analysis.",
68-
)
69-
parser.add_argument(
70-
"--disallow-empty-results",
71-
action="store_true",
72-
help="Raise error if no results are returned for query.",
73-
)
74-
parser.add_argument(
75-
"--use-prefixes",
76-
action="store_true",
77-
help="Abbreviate node IDs according to graph's encoded prefixes. (This will use prefixes in the graph, not the query.)",
78-
)
79-
parser.add_argument(
80-
"out_table",
81-
help="Expected extensions are .html for HTML tables, .md for Markdown tables, .csv for comma-separated values, and .tsv for tab-separated values.",
82-
)
83-
parser.add_argument(
84-
"in_sparql",
85-
help="File containing a SPARQL SELECT query. Note that prefixes not mapped with a PREFIX statement will be mapped according to their first occurrence among input graphs.",
86-
)
87-
parser.add_argument("in_graph", nargs="+")
88-
args = parser.parse_args()
89-
90-
graph = rdflib.Graph()
91-
for in_graph_filename in args.in_graph:
92-
graph.parse(in_graph_filename)
93-
94-
# Inherit prefixes defined in input context dictionary.
95-
nsdict = {k: v for (k, v) in graph.namespace_manager.namespaces()}
96-
97-
select_query_text = None
98-
with open(args.in_sparql, "r") as in_fh:
99-
select_query_text = in_fh.read().strip()
100-
_logger.debug("select_query_text = %r." % select_query_text)
101-
102-
if "subClassOf" in select_query_text:
103-
case_utils.ontology.load_subclass_hierarchy(
104-
graph, built_version=args.built_version
105-
)
106-
52+
def query_text_to_variables(select_query_text: str) -> typing.List[str]:
    """
    Extract the projected variable names from a SPARQL SELECT query's text.

    :param select_query_text: Full text of a SPARQL SELECT query.  The line
      beginning the projection must start with the literal prefix "SELECT ".
    :returns: The space-separated variable tokens from the first SELECT line,
      with any " DISTINCT" keyword removed.
    :raises IndexError: If no line of the query starts with "SELECT ".
    """
    # Locate the line that opens the SELECT clause.  Indexing [0] preserves
    # the IndexError contract when no such line exists.
    select_lines = [
        query_line
        for query_line in select_query_text.split("\n")
        if query_line.startswith("SELECT ")
    ]
    select_line = select_lines[0]

    # Drop the keywords, leaving only the space-separated variable tokens.
    variables_text = select_line.replace(" DISTINCT", "").replace("SELECT ", "")
    return variables_text.split(" ")
60+
61+
62+
def graph_and_query_to_data_frame(
63+
graph: rdflib.Graph,
64+
select_query_text: str,
65+
*args: typing.Any,
66+
built_version: str = "case-" + CURRENT_CASE_VERSION,
67+
disallow_empty_results: bool = False,
68+
use_prefixes: bool = False,
69+
**kwargs: typing.Any,
70+
) -> pd.DataFrame:
71+
# Inherit prefixes defined in input context dictionary.
72+
nsdict = {k: v for (k, v) in graph.namespace_manager.namespaces()}
73+
74+
# Avoid side-effects on input parameter.
75+
if "subClassOf" in select_query_text:
76+
_graph = rdflib.Graph()
77+
_graph += graph
78+
case_utils.ontology.load_subclass_hierarchy(_graph, built_version=built_version)
79+
else:
80+
_graph = graph
81+
82+
variables = query_text_to_variables(select_query_text)
11383

11484
tally = 0
11585
records = []
11686
select_query_object = rdflib.plugins.sparql.processor.prepareQuery(
11787
select_query_text, initNs=nsdict
11888
)
119-
for (row_no, row) in enumerate(graph.query(select_query_object)):
89+
for (row_no, row) in enumerate(_graph.query(select_query_object)):
12090
tally = row_no + 1
12191
record = []
12292
for (column_no, column) in enumerate(row):
@@ -131,7 +101,7 @@ def main() -> None:
131101
# .decode() is because hexlify returns bytes.
132102
column_value = binascii.hexlify(column.toPython()).decode()
133103
elif isinstance(column, rdflib.URIRef):
134-
if args.use_prefixes:
104+
if use_prefixes:
135105
column_value = graph.namespace_manager.qname(column.toPython())
136106
else:
137107
column_value = column.toPython()
@@ -141,39 +111,128 @@ def main() -> None:
141111
_logger.debug("row[0]column[%d] = %r." % (column_no, column_value))
142112
record.append(column_value)
143113
records.append(record)
114+
144115
if tally == 0:
145-
if args.disallow_empty_results:
116+
if disallow_empty_results:
146117
raise ValueError("Failed to return any results.")
147118

148119
df = pd.DataFrame(records, columns=variables)
120+
return df
121+
149122

123+
def data_frame_to_table_text(
    df: pd.DataFrame,
    *args: typing.Any,
    output_mode: str,
    **kwargs: typing.Any,
) -> str:
    """
    Render a Pandas DataFrame as table text in the requested format.

    :param df: Data frame to serialize.
    :param output_mode: One of "csv", "tsv", "html", or "md".
    :returns: The table serialized as a string.
    :raises NotImplementedError: If output_mode is not a recognized mode.
    """
    table_text: str

    if output_mode in {"csv", "tsv"}:
        # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_csv.html
        # The two modes differ only in the field separator.
        sep: str = "," if output_mode == "csv" else "\t"
        table_text = df.to_csv(sep=sep)
    elif output_mode == "html":
        # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_html.html
        # Add CSS classes for CASE website Bootstrap support.
        table_text = df.to_html(classes=("table", "table-bordered", "table-condensed"))
    elif output_mode == "md":
        # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_markdown.html
        # https://pypi.org/project/tabulate/
        # Assume Github-flavored Markdown.
        table_text = df.to_markdown(tablefmt="github")
    else:
        # The prior implementation nested a dead "if table_text is None:"
        # check here (always true on this branch) and followed with a
        # redundant assert; raise unconditionally instead.
        raise NotImplementedError("Unimplemented output mode: %r." % output_mode)

    return table_text
158+
159+
160+
def main() -> None:
161+
parser = argparse.ArgumentParser()
162+
163+
# Configure debug logging before running parse_args, because there could be an error raised before the construction of the argument parser.
164+
logging.basicConfig(
165+
level=logging.DEBUG
166+
if ("--debug" in sys.argv or "-d" in sys.argv)
167+
else logging.INFO
168+
)
169+
170+
parser.add_argument("-d", "--debug", action="store_true")
171+
parser.add_argument(
172+
"--built-version",
173+
choices=tuple(built_version_choices_list),
174+
default="case-" + CURRENT_CASE_VERSION,
175+
help="Ontology version to use to supplement query, such as for subclass querying. Does not require networking to use. Default is most recent CASE release. Passing 'none' will mean no pre-built CASE ontology versions accompanying this tool will be included in the analysis.",
176+
)
177+
parser.add_argument(
178+
"--disallow-empty-results",
179+
action="store_true",
180+
help="Raise error if no results are returned for query.",
181+
)
182+
parser.add_argument(
183+
"--use-prefixes",
184+
action="store_true",
185+
help="Abbreviate node IDs according to graph's encoded prefixes. (This will use prefixes in the graph, not the query.)",
186+
)
187+
parser.add_argument(
188+
"out_table",
189+
help="Expected extensions are .html for HTML tables, .md for Markdown tables, .csv for comma-separated values, and .tsv for tab-separated values.",
190+
)
191+
parser.add_argument(
192+
"in_sparql",
193+
help="File containing a SPARQL SELECT query. Note that prefixes not mapped with a PREFIX statement will be mapped according to their first occurrence among input graphs.",
194+
)
195+
196+
parser.add_argument("in_graph", nargs="+")
197+
args = parser.parse_args()
176198

199+
output_mode: str
200+
if args.out_table.endswith(".csv"):
201+
output_mode = "csv"
202+
elif args.out_table.endswith(".html"):
203+
output_mode = "html"
204+
elif args.out_table.endswith(".json"):
205+
output_mode = "json"
206+
elif args.out_table.endswith(".md"):
207+
output_mode = "md"
208+
elif args.out_table.endswith(".tsv"):
209+
output_mode = "tsv"
210+
else:
211+
raise NotImplementedError("Output file extension not implemented.")
212+
213+
graph = rdflib.Graph()
214+
for in_graph_filename in args.in_graph:
215+
graph.parse(in_graph_filename)
216+
217+
select_query_text: typing.Optional[str] = None
218+
with open(args.in_sparql, "r") as in_fh:
219+
select_query_text = in_fh.read().strip()
220+
if select_query_text is None:
221+
raise ValueError("Failed to load query.")
222+
_logger.debug("select_query_text = %r." % select_query_text)
223+
224+
df = graph_and_query_to_data_frame(
225+
graph,
226+
select_query_text,
227+
built_version=args.built_version,
228+
disallow_empty_results=args.disallow_empty_results is True,
229+
use_prefixes=args.use_prefixes is True,
230+
)
231+
232+
table_text = data_frame_to_table_text(
233+
df,
234+
output_mode=output_mode,
235+
)
177236
with open(args.out_table, "w") as out_fh:
178237
out_fh.write(table_text)
179238
if table_text[-1] != "\n":

0 commit comments

Comments
 (0)