Skip to content

Commit 29f6a1e

Browse files
[7.x] Forward auth parameters to APIs in scan helpers
Co-authored-by: Seth Michael Larson <seth.larson@elastic.co>
1 parent 0bfac34 commit 29f6a1e

File tree

4 files changed

+198
-0
lines changed

4 files changed

+198
-0
lines changed

elasticsearch/_async/helpers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,20 @@ async def async_scan(
338338
query = query.copy() if query else {}
339339
query["sort"] = "_doc"
340340

341+
# Grab options that should be propagated to every
342+
# API call within this helper instead of just 'search()'
343+
transport_kwargs = {}
344+
for key in ("headers", "api_key", "http_auth"):
345+
if key in kwargs:
346+
transport_kwargs[key] = kwargs[key]
347+
348+
# If the user is using 'scroll_kwargs' we want
349+
# to propagate there too, but to not break backwards
350+
# compatibility we'll not override anything already given.
351+
if scroll_kwargs is not None and transport_kwargs:
352+
for key, val in transport_kwargs.items():
353+
scroll_kwargs.setdefault(key, val)
354+
341355
# initial search
342356
resp = await client.search(
343357
body=query, scroll=scroll, size=size, request_timeout=request_timeout, **kwargs
@@ -382,6 +396,7 @@ async def async_scan(
382396
if scroll_id and clear_scroll:
383397
await client.clear_scroll(
384398
body={"scroll_id": [scroll_id]},
399+
**transport_kwargs,
385400
ignore=(404,),
386401
params={"__elastic_client_meta": (("h", "s"),)},
387402
)

elasticsearch/helpers/actions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,20 @@ def scan(
550550
query = query.copy() if query else {}
551551
query["sort"] = "_doc"
552552

553+
# Grab options that should be propagated to every
554+
# API call within this helper instead of just 'search()'
555+
transport_kwargs = {}
556+
for key in ("headers", "api_key", "http_auth"):
557+
if key in kwargs:
558+
transport_kwargs[key] = kwargs[key]
559+
560+
# If the user is using 'scroll_kwargs' we want
561+
# to propagate there too, but to not break backwards
562+
# compatibility we'll not override anything already given.
563+
if scroll_kwargs is not None and transport_kwargs:
564+
for key, val in transport_kwargs.items():
565+
scroll_kwargs.setdefault(key, val)
566+
553567
# initial search
554568
resp = client.search(
555569
body=query, scroll=scroll, size=size, request_timeout=request_timeout, **kwargs
@@ -596,6 +610,7 @@ def scan(
596610
body={"scroll_id": [scroll_id]},
597611
ignore=(404,),
598612
params={"__elastic_client_meta": (("h", "s"),)},
613+
**transport_kwargs
599614
)
600615

601616

test_elasticsearch/test_async/test_server/test_helpers.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class AsyncMock(MagicMock):
3434
async def __call__(self, *args, **kwargs):
3535
return super(AsyncMock, self).__call__(*args, **kwargs)
3636

37+
def __await__(self):
38+
return self().__await__()
39+
3740

3841
class FailingBulkClient(object):
3942
def __init__(
@@ -419,6 +422,9 @@ def __init__(self, resp):
419422
async def __call__(self, *args, **kwargs):
420423
return self.resp
421424

425+
def __await__(self):
426+
return self().__await__()
427+
422428

423429
@pytest.fixture(scope="function")
424430
async def scan_teardown(async_client):
@@ -645,6 +651,105 @@ async def test_clear_scroll(self, async_client, scan_teardown):
645651
]
646652
spy.assert_not_called()
647653

654+
@pytest.mark.parametrize(
655+
"kwargs",
656+
[
657+
{"api_key": ("name", "value")},
658+
{"http_auth": ("username", "password")},
659+
{"headers": {"custom", "header"}},
660+
],
661+
)
662+
async def test_scan_auth_kwargs_forwarded(
663+
self, async_client, scan_teardown, kwargs
664+
):
665+
((key, val),) = kwargs.items()
666+
667+
with patch.object(
668+
async_client,
669+
"search",
670+
return_value=MockResponse(
671+
{
672+
"_scroll_id": "scroll_id",
673+
"_shards": {"successful": 5, "total": 5, "skipped": 0},
674+
"hits": {"hits": [{"search_data": 1}]},
675+
}
676+
),
677+
) as search_mock:
678+
with patch.object(
679+
async_client,
680+
"scroll",
681+
return_value=MockResponse(
682+
{
683+
"_scroll_id": "scroll_id",
684+
"_shards": {"successful": 5, "total": 5, "skipped": 0},
685+
"hits": {"hits": []},
686+
}
687+
),
688+
) as scroll_mock:
689+
with patch.object(
690+
async_client, "clear_scroll", return_value=MockResponse({})
691+
) as clear_mock:
692+
data = [
693+
x
694+
async for x in helpers.async_scan(
695+
async_client, index="test_index", **kwargs
696+
)
697+
]
698+
699+
assert data == [{"search_data": 1}]
700+
701+
for api_mock in (search_mock, scroll_mock, clear_mock):
702+
assert api_mock.call_args[1][key] == val
703+
704+
async def test_scan_auth_kwargs_favor_scroll_kwargs_option(
705+
self, async_client, scan_teardown
706+
):
707+
with patch.object(
708+
async_client,
709+
"search",
710+
return_value=MockResponse(
711+
{
712+
"_scroll_id": "scroll_id",
713+
"_shards": {"successful": 5, "total": 5, "skipped": 0},
714+
"hits": {"hits": [{"search_data": 1}]},
715+
}
716+
),
717+
):
718+
with patch.object(
719+
async_client,
720+
"scroll",
721+
return_value=MockResponse(
722+
{
723+
"_scroll_id": "scroll_id",
724+
"_shards": {"successful": 5, "total": 5, "skipped": 0},
725+
"hits": {"hits": []},
726+
}
727+
),
728+
):
729+
with patch.object(
730+
async_client, "clear_scroll", return_value=MockResponse({})
731+
):
732+
data = [
733+
x
734+
async for x in helpers.async_scan(
735+
async_client,
736+
index="test_index",
737+
headers={"not scroll": "kwargs"},
738+
scroll_kwargs={
739+
"headers": {"scroll": "kwargs"},
740+
"sort": "asc",
741+
},
742+
)
743+
]
744+
745+
assert data == [{"search_data": 1}]
746+
747+
# Assert that we see 'scroll_kwargs' options used instead of 'kwargs'
748+
assert async_client.scroll.call_args[1]["headers"] == {
749+
"scroll": "kwargs"
750+
}
751+
assert async_client.scroll.call_args[1]["sort"] == "asc"
752+
648753

649754
@pytest.fixture(scope="function")
650755
async def reindex_setup(async_client):

test_elasticsearch/test_server/test_helpers.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,69 @@ def test_no_scroll_id_fast_route(self):
478478
client_mock.scroll.assert_not_called()
479479
client_mock.clear_scroll.assert_not_called()
480480

481+
def test_scan_auth_kwargs_forwarded(self):
482+
for key, val in {
483+
"api_key": ("name", "value"),
484+
"http_auth": ("username", "password"),
485+
"headers": {"custom": "header"},
486+
}.items():
487+
with patch.object(self, "client") as client_mock:
488+
client_mock.search.return_value = {
489+
"_scroll_id": "scroll_id",
490+
"_shards": {"successful": 5, "total": 5, "skipped": 0},
491+
"hits": {"hits": [{"search_data": 1}]},
492+
}
493+
client_mock.scroll.return_value = {
494+
"_scroll_id": "scroll_id",
495+
"_shards": {"successful": 5, "total": 5, "skipped": 0},
496+
"hits": {"hits": []},
497+
}
498+
client_mock.clear_scroll.return_value = {}
499+
500+
data = list(helpers.scan(self.client, index="test_index", **{key: val}))
501+
502+
self.assertEqual(data, [{"search_data": 1}])
503+
504+
# Assert that 'search', 'scroll' and 'clear_scroll' all
505+
# received the extra kwarg related to authentication.
506+
for api_mock in (
507+
client_mock.search,
508+
client_mock.scroll,
509+
client_mock.clear_scroll,
510+
):
511+
self.assertEqual(api_mock.call_args[1][key], val)
512+
513+
def test_scan_auth_kwargs_favor_scroll_kwargs_option(self):
514+
with patch.object(self, "client") as client_mock:
515+
client_mock.search.return_value = {
516+
"_scroll_id": "scroll_id",
517+
"_shards": {"successful": 5, "total": 5, "skipped": 0},
518+
"hits": {"hits": [{"search_data": 1}]},
519+
}
520+
client_mock.scroll.return_value = {
521+
"_scroll_id": "scroll_id",
522+
"_shards": {"successful": 5, "total": 5, "skipped": 0},
523+
"hits": {"hits": []},
524+
}
525+
client_mock.clear_scroll.return_value = {}
526+
527+
data = list(
528+
helpers.scan(
529+
self.client,
530+
index="test_index",
531+
scroll_kwargs={"headers": {"scroll": "kwargs"}, "sort": "asc"},
532+
headers={"not scroll": "kwargs"},
533+
)
534+
)
535+
536+
self.assertEqual(data, [{"search_data": 1}])
537+
538+
# Assert that we see 'scroll_kwargs' options used instead of 'kwargs'
539+
self.assertEqual(
540+
client_mock.scroll.call_args[1]["headers"], {"scroll": "kwargs"}
541+
)
542+
self.assertEqual(client_mock.scroll.call_args[1]["sort"], "asc")
543+
481544
@patch("elasticsearch.helpers.actions.logger")
482545
def test_logger(self, logger_mock):
483546
bulk = []

0 commit comments

Comments
 (0)