Skip to content

Commit e025ee0

Browse files
authored
Merge pull request #241 from djarecka/enh/result_verb
adding verbose as an optional argument to Task.result
2 parents 5c5991d + 52dd86e commit e025ee0

File tree

3 files changed

+181
-32
lines changed

3 files changed

+181
-32
lines changed

pydra/engine/core.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -532,29 +532,37 @@ def done(self):
532532
return True
533533
return False
534534

535-
def _combined_output(self):
535+
def _combined_output(self, return_inputs=False):
536536
combined_results = []
537537
for (gr, ind_l) in self.state.final_combined_ind_mapping.items():
538-
combined_results.append([])
538+
combined_results_gr = []
539539
for ind in ind_l:
540540
result = load_result(self.checksum_states(ind), self.cache_locations)
541541
if result is None:
542542
return None
543-
combined_results[gr].append(result)
543+
if return_inputs is True or return_inputs == "val":
544+
result = (self.state.states_val[ind], result)
545+
elif return_inputs == "ind":
546+
result = (self.state.states_ind[ind], result)
547+
combined_results_gr.append(result)
548+
combined_results.append(combined_results_gr)
544549
if len(combined_results) == 1 and self.state.splitter_rpn_final == []:
545550
# in case it's full combiner, removing the nested structure
546551
return combined_results[0]
547552
else:
548553
return combined_results
549554

550-
def result(self, state_index=None):
555+
def result(self, state_index=None, return_inputs=False):
551556
"""
552557
Retrieve the outcomes of this particular task.
553558
554559
Parameters
555560
----------
556-
state_index :
557-
TODO
561+
state_index : :obj: `int`
562+
index of the element for task with splitter and multiple states
563+
return_inputs : :obj: `bool`, :obj:`str`
564+
if True or "val" result is returned together with values of the input fields,
565+
if "ind" result is returned together with indices of the input fields
558566
559567
Returns
560568
-------
@@ -567,28 +575,50 @@ def result(self, state_index=None):
567575
if state_index is None:
568576
# if state_index=None, collecting all results
569577
if self.state.combiner:
570-
return self._combined_output()
578+
return self._combined_output(return_inputs=return_inputs)
571579
else:
572580
results = []
573581
for checksum in self.checksum_states():
574582
result = load_result(checksum, self.cache_locations)
575583
if result is None:
576584
return None
577585
results.append(result)
578-
return results
586+
if return_inputs is True or return_inputs == "val":
587+
return list(zip(self.state.states_val, results))
588+
elif return_inputs == "ind":
589+
return list(zip(self.state.states_ind, results))
590+
else:
591+
return results
579592
else: # state_index is not None
580593
if self.state.combiner:
581-
return self._combined_output()[state_index]
594+
return self._combined_output(return_inputs=return_inputs)[
595+
state_index
596+
]
582597
result = load_result(
583598
self.checksum_states(state_index), self.cache_locations
584599
)
585-
return result
600+
if return_inputs is True or return_inputs == "val":
601+
return (self.state.states_val[state_index], result)
602+
elif return_inputs == "ind":
603+
return (self.state.states_ind[state_index], result)
604+
else:
605+
return result
586606
else:
587607
if state_index is not None:
588608
raise ValueError("Task does not have a state")
589609
checksum = self.checksum
590610
result = load_result(checksum, self.cache_locations)
591-
return result
611+
if return_inputs is True or return_inputs == "val":
612+
inputs_val = {
613+
f"{self.name}.{inp}": getattr(self.inputs, inp)
614+
for inp in self.input_names
615+
}
616+
return (inputs_val, result)
617+
elif return_inputs == "ind":
618+
inputs_ind = {f"{self.name}.{inp}": None for inp in self.input_names}
619+
return (inputs_ind, result)
620+
else:
621+
return result
592622

593623
def _reset(self):
594624
"""Reset the connections between inputs and LazyFields."""

pydra/engine/tests/test_node_task.py

Lines changed: 113 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,18 @@ def test_task_nostate_1(plugin):
368368
# checking the results
369369
results = nn.result()
370370
assert results.output.out == 5
371+
# checking the return_inputs option, either is return_inputs is True, or "val",
372+
# it should give values of inputs that corresponds to the specific element
373+
results_verb = nn.result(return_inputs=True)
374+
results_verb_val = nn.result(return_inputs="val")
375+
assert results_verb[0] == results_verb_val[0] == {"NA.a": 3}
376+
assert results_verb[1].output.out == results_verb_val[1].output.out == 5
377+
# checking the return_inputs option return_inputs="ind"
378+
# it should give indices of inputs (instead of values) for each element
379+
results_verb_ind = nn.result(return_inputs="ind")
380+
assert results_verb_ind[0] == {"NA.a": None}
381+
assert results_verb_ind[1].output.out == 5
382+
371383
# checking the output_dir
372384
assert nn.output_dir.exists()
373385

@@ -709,6 +721,22 @@ def test_task_state_1(plugin):
709721
expected = [({"NA.a": 3}, 5), ({"NA.a": 5}, 7)]
710722
for i, res in enumerate(expected):
711723
assert results[i].output.out == res[1]
724+
725+
# checking the return_inputs option, either return_inputs is True or "val",
726+
# it should give values of inputs that corresponds to the specific element
727+
results_verb = nn.result(return_inputs=True)
728+
results_verb_val = nn.result(return_inputs="val")
729+
for i, res in enumerate(expected):
730+
assert (results_verb[i][0], results_verb[i][1].output.out) == res
731+
assert (results_verb_val[i][0], results_verb_val[i][1].output.out) == res
732+
733+
# checking the return_inputs option return_inputs="ind"
734+
# it should give indices of inputs (instead of values) for each element
735+
results_verb_ind = nn.result(return_inputs="ind")
736+
expected_ind = [({"NA.a": 0}, 5), ({"NA.a": 1}, 7)]
737+
for i, res in enumerate(expected_ind):
738+
assert (results_verb_ind[i][0], results_verb_ind[i][1].output.out) == res
739+
712740
# checking the output_dir
713741
assert nn.output_dir
714742
for odir in nn.output_dir:
@@ -737,13 +765,14 @@ def test_task_state_1a(plugin):
737765

738766

739767
@pytest.mark.parametrize(
740-
"splitter, state_splitter, state_rpn, expected",
768+
"splitter, state_splitter, state_rpn, expected, expected_ind",
741769
[
742770
(
743771
("a", "b"),
744772
("NA.a", "NA.b"),
745773
["NA.a", "NA.b", "."],
746774
[({"NA.a": 3, "NA.b": 10}, 13), ({"NA.a": 5, "NA.b": 20}, 25)],
775+
[({"NA.a": 0, "NA.b": 0}, 13), ({"NA.a": 1, "NA.b": 1}, 25)],
747776
),
748777
(
749778
["a", "b"],
@@ -755,11 +784,19 @@ def test_task_state_1a(plugin):
755784
({"NA.a": 5, "NA.b": 10}, 15),
756785
({"NA.a": 5, "NA.b": 20}, 25),
757786
],
787+
[
788+
({"NA.a": 0, "NA.b": 0}, 13),
789+
({"NA.a": 0, "NA.b": 1}, 23),
790+
({"NA.a": 1, "NA.b": 0}, 15),
791+
({"NA.a": 1, "NA.b": 1}, 25),
792+
],
758793
),
759794
],
760795
)
761796
@pytest.mark.parametrize("plugin", Plugins)
762-
def test_task_state_2(plugin, splitter, state_splitter, state_rpn, expected):
797+
def test_task_state_2(
798+
plugin, splitter, state_splitter, state_rpn, expected, expected_ind
799+
):
763800
""" Tasks with two inputs and a splitter (no combiner)"""
764801
nn = fun_addvar(name="NA").split(splitter=splitter, a=[3, 5], b=[10, 20])
765802

@@ -777,6 +814,21 @@ def test_task_state_2(plugin, splitter, state_splitter, state_rpn, expected):
777814
results = nn.result()
778815
for i, res in enumerate(expected):
779816
assert results[i].output.out == res[1]
817+
818+
# checking the return_inputs option, either return_inputs is True or "val",
819+
# it should give values of inputs that corresponds to the specific element
820+
results_verb = nn.result(return_inputs=True)
821+
results_verb_val = nn.result(return_inputs="val")
822+
for i, res in enumerate(expected):
823+
assert (results_verb[i][0], results_verb[i][1].output.out) == res
824+
assert (results_verb_val[i][0], results_verb_val[i][1].output.out) == res
825+
826+
# checking the return_inputs option return_inputs="ind"
827+
# it should give indices of inputs (instead of values) for each element
828+
results_verb_ind = nn.result(return_inputs="ind")
829+
for i, res in enumerate(expected_ind):
830+
assert (results_verb_ind[i][0], results_verb_ind[i][1].output.out) == res
831+
780832
# checking the output_dir
781833
assert nn.output_dir
782834
for odir in nn.output_dir:
@@ -982,11 +1034,25 @@ def test_task_state_comb_1(plugin):
9821034

9831035
# checking the results
9841036
results = nn.result()
985-
9861037
# fully combined (no nested list)
9871038
combined_results = [res.output.out for res in results]
988-
9891039
assert combined_results == [5, 7]
1040+
1041+
expected = [({"NA.a": 3}, 5), ({"NA.a": 5}, 7)]
1042+
expected_ind = [({"NA.a": 0}, 5), ({"NA.a": 1}, 7)]
1043+
# checking the return_inputs option, either return_inputs is True or "val",
1044+
# it should give values of inputs that corresponds to the specific element
1045+
results_verb = nn.result(return_inputs=True)
1046+
results_verb_val = nn.result(return_inputs="val")
1047+
for i, res in enumerate(expected):
1048+
assert (results_verb[i][0], results_verb[i][1].output.out) == res
1049+
assert (results_verb_val[i][0], results_verb_val[i][1].output.out) == res
1050+
# checking the return_inputs option return_inputs="ind"
1051+
# it should give indices of inputs (instead of values) for each element
1052+
results_verb_ind = nn.result(return_inputs="ind")
1053+
for i, res in enumerate(expected_ind):
1054+
assert (results_verb_ind[i][0], results_verb_ind[i][1].output.out) == res
1055+
9901056
# checking the output_dir
9911057
assert nn.output_dir
9921058
for odir in nn.output_dir:
@@ -995,7 +1061,7 @@ def test_task_state_comb_1(plugin):
9951061

9961062
@pytest.mark.parametrize(
9971063
"splitter, combiner, state_splitter, state_rpn, state_combiner, state_combiner_all, "
998-
"state_splitter_final, state_rpn_final, expected",
1064+
"state_splitter_final, state_rpn_final, expected, expected_val",
9991065
[
10001066
(
10011067
("a", "b"),
@@ -1006,7 +1072,8 @@ def test_task_state_comb_1(plugin):
10061072
["NA.a", "NA.b"],
10071073
None,
10081074
[],
1009-
[({}, [13, 25])],
1075+
[13, 25],
1076+
[({"NA.a": 3, "NA.b": 10}, 13), ({"NA.a": 5, "NA.b": 20}, 25)],
10101077
),
10111078
(
10121079
("a", "b"),
@@ -1017,7 +1084,8 @@ def test_task_state_comb_1(plugin):
10171084
["NA.a", "NA.b"],
10181085
None,
10191086
[],
1020-
[({}, [13, 25])],
1087+
[13, 25],
1088+
[({"NA.a": 3, "NA.b": 10}, 13), ({"NA.a": 5, "NA.b": 20}, 25)],
10211089
),
10221090
(
10231091
["a", "b"],
@@ -1028,7 +1096,11 @@ def test_task_state_comb_1(plugin):
10281096
["NA.a"],
10291097
"NA.b",
10301098
["NA.b"],
1031-
[({"NA.b": 10}, [13, 15]), ({"NA.b": 20}, [23, 25])],
1099+
[[13, 15], [23, 25]],
1100+
[
1101+
[({"NA.a": 3, "NA.b": 10}, 13), ({"NA.a": 5, "NA.b": 10}, 15)],
1102+
[({"NA.a": 3, "NA.b": 20}, 23), ({"NA.a": 5, "NA.b": 20}, 25)],
1103+
],
10321104
),
10331105
(
10341106
["a", "b"],
@@ -1039,7 +1111,11 @@ def test_task_state_comb_1(plugin):
10391111
["NA.b"],
10401112
"NA.a",
10411113
["NA.a"],
1042-
[({"NA.a": 3}, [13, 23]), ({"NA.a": 5}, [15, 25])],
1114+
[[13, 23], [15, 25]],
1115+
[
1116+
[({"NA.a": 3, "NA.b": 10}, 13), ({"NA.a": 3, "NA.b": 20}, 23)],
1117+
[({"NA.a": 5, "NA.b": 10}, 15), ({"NA.a": 5, "NA.b": 20}, 25)],
1118+
],
10431119
),
10441120
(
10451121
["a", "b"],
@@ -1050,7 +1126,13 @@ def test_task_state_comb_1(plugin):
10501126
["NA.a", "NA.b"],
10511127
None,
10521128
[],
1053-
[({}, [13, 23, 15, 25])],
1129+
[13, 23, 15, 25],
1130+
[
1131+
({"NA.a": 3, "NA.b": 10}, 13),
1132+
({"NA.a": 3, "NA.b": 20}, 23),
1133+
({"NA.a": 5, "NA.b": 10}, 15),
1134+
({"NA.a": 5, "NA.b": 20}, 25),
1135+
],
10541136
),
10551137
],
10561138
)
@@ -1066,6 +1148,7 @@ def test_task_state_comb_2(
10661148
state_splitter_final,
10671149
state_rpn_final,
10681150
expected,
1151+
expected_val,
10691152
):
10701153
""" Tasks with scalar and outer splitters and partial or full combiners"""
10711154
nn = (
@@ -1080,19 +1163,32 @@ def test_task_state_comb_2(
10801163
assert nn.state.splitter_rpn == state_rpn
10811164
assert nn.state.combiner == state_combiner
10821165

1166+
with Submitter(plugin=plugin) as sub:
1167+
sub(nn)
1168+
10831169
assert nn.state.splitter_final == state_splitter_final
10841170
assert nn.state.splitter_rpn_final == state_rpn_final
10851171
assert set(nn.state.right_combiner_all) == set(state_combiner_all)
10861172

1087-
with Submitter(plugin=plugin) as sub:
1088-
sub(nn)
1089-
10901173
# checking the results
10911174
results = nn.result()
1175+
# checking the return_inputs option, either return_inputs is True or "val",
1176+
# it should give values of inputs that corresponds to the specific element
1177+
results_verb = nn.result(return_inputs=True)
1178+
1179+
if nn.state.splitter_rpn_final:
1180+
for i, res in enumerate(expected):
1181+
assert [res.output.out for res in results[i]] == res
1182+
# results_verb
1183+
for i, res_l in enumerate(expected_val):
1184+
for j, res in enumerate(res_l):
1185+
assert (results_verb[i][j][0], results_verb[i][j][1].output.out) == res
1186+
# if the combiner is full expected is "a flat list"
1187+
else:
1188+
assert [res.output.out for res in results] == expected
1189+
for i, res in enumerate(expected_val):
1190+
assert (results_verb[i][0], results_verb[i][1].output.out) == res
10921191

1093-
combined_results = [[res.output.out for res in res_l] for res_l in results]
1094-
for i, res in enumerate(expected):
1095-
assert combined_results[i] == res[1]
10961192
# checking the output_dir
10971193
assert nn.output_dir
10981194
for odir in nn.output_dir:
@@ -1130,7 +1226,7 @@ def test_task_state_comb_singl_1(plugin):
11301226

11311227

11321228
@pytest.mark.parametrize("plugin", Plugins)
1133-
def test_task_state_comb_2(plugin):
1229+
def test_task_state_comb_3(plugin):
11341230
""" task with the simplest splitter, the input is an empty list"""
11351231
nn = fun_addtwo(name="NA").split(splitter="a", a=[]).combine(combiner=["a"])
11361232

0 commit comments

Comments
 (0)