diff --git a/prompt_toolkit/shortcuts/dialogs.py b/prompt_toolkit/shortcuts/dialogs.py index fa8d0554d..742b7f4c7 100644 --- a/prompt_toolkit/shortcuts/dialogs.py +++ b/prompt_toolkit/shortcuts/dialogs.py @@ -111,14 +111,14 @@ def message_dialog(title='', text='', ok_text='Ok', style=None, async_=False): def radiolist_dialog(title='', text='', ok_text='Ok', cancel_text='Cancel', - values=None, style=None, async_=False): + values=None, style=None, async_=False, default=0): """ Display a simple message box and wait until the user presses enter. """ def ok_handler(): get_app().exit(result=radio_list.current_value) - radio_list = RadioList(values) + radio_list = RadioList(values, default) dialog = Dialog( title=title, diff --git a/prompt_toolkit/widgets/base.py b/prompt_toolkit/widgets/base.py index a90709ba0..eeee7aeb5 100644 --- a/prompt_toolkit/widgets/base.py +++ b/prompt_toolkit/widgets/base.py @@ -497,16 +497,20 @@ class RadioList(object): List of radio buttons. Only one can be checked at the same time. :param values: List of (value, label) tuples. + :param default: Default index to select, defaults to 0 """ - def __init__(self, values): + def __init__(self, values, default=0): assert isinstance(values, list) assert len(values) > 0 assert all(isinstance(i, tuple) and len(i) == 2 for i in values) + assert isinstance(default, int) + assert 0 <= default < len(values) self.values = values - self.current_value = values[0][0] - self._selected_index = 0 + + self.current_value = values[default][0] + self._selected_index = default # Key bindings. kb = KeyBindings() diff --git a/tests/test_radiolist.py b/tests/test_radiolist.py new file mode 100644 index 000000000..4b414ff4b --- /dev/null +++ b/tests/test_radiolist.py @@ -0,0 +1,96 @@ +from __future__ import unicode_literals + +from prompt_toolkit.widgets import RadioList + +import pytest + +def test_initial(): + values = [ + (1, 'some_text'), + ('fizz', 'some_more_text'), + ({'foo':'bar'}, 'even_more_text') + ] + radiolist = RadioList(values) + assert radiolist.current_value == 1 + assert radiolist._selected_index == 0 + +def test_default(): + values = [ + (1, 'some_text'), + ('fizz', 'some_more_text'), + ({'foo':'bar'}, 'even_more_text') + ] + radiolist = RadioList(values, 1) + assert radiolist.current_value == 'fizz' + assert radiolist._selected_index == 1 + +def test_bad_params(): + with pytest.raises(AssertionError): + radiolist = RadioList([]) + + with pytest.raises(AssertionError): + radiolist = RadioList(None) + + with pytest.raises(AssertionError): + values = ( + (1, 'some_text'), + ('fizz', 'some_more_text'), + ({'foo':'bar'}, 'even_more_text') + ) + radiolist = RadioList(values) + + with pytest.raises(AssertionError): + values = [ + (1, 'some_text', 'whoops'), + ('fizz', 'some_more_text'), + ({'foo':'bar'}, 'even_more_text') + ] + radiolist = RadioList(values) + + with pytest.raises(AssertionError): + values = [ + (1, 'some_text'), + ('fizz', ), + ({'foo':'bar'}, 'even_more_text') + ] + radiolist = RadioList(values) + + with pytest.raises(AssertionError): + values = [ + (1, 'some_text'), + ('fizz', 'some_more_text'), + [{'foo':'bar'}, 'even_more_text'] + ] + radiolist = RadioList(values) + + with pytest.raises(AssertionError): + values = [ + (1, 'some_text'), + ('fizz', 'some_more_text'), + ({'foo':'bar'}, 'even_more_text') + ] + radiolist = RadioList(values, 3) + + with pytest.raises(AssertionError): + values = [ + (1, 'some_text'), + ('fizz', 'some_more_text'), + ({'foo':'bar'}, 'even_more_text') + ] + radiolist = RadioList(values, -1) + + with pytest.raises(AssertionError): + values = [ + (1, 'some_text'), + ('fizz', 'some_more_text'), + ({'foo':'bar'}, 'even_more_text') + ] + radiolist = RadioList(values, None) + + with pytest.raises(AssertionError): + values = [ + (1, 'some_text'), + ('fizz', 'some_more_text'), + ({'foo':'bar'}, 'even_more_text') + ] + radiolist = RadioList(values, 'whoops')