Skip to content

Commit af1b6ba

Browse files
committed
Add support for WHERE clause in copy_to methods
1 parent 9825bbb commit af1b6ba

File tree

3 files changed

+63
-12
lines changed

3 files changed

+63
-12
lines changed

asyncpg/connection.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ async def copy_to_table(self, table_name, *, source,
827827
delimiter=None, null=None, header=None,
828828
quote=None, escape=None, force_quote=None,
829829
force_not_null=None, force_null=None,
830-
encoding=None):
830+
encoding=None, where=None):
831831
"""Copy data to the specified table.
832832
833833
:param str table_name:
@@ -846,6 +846,15 @@ async def copy_to_table(self, table_name, *, source,
846846
:param str schema_name:
847847
An optional schema name to qualify the table.
848848
849+
:param str where:
850+
An optional condition used to filter rows when copying.
851+
852+
.. note::
853+
854+
Usage of this parameter requires support for the
855+
``COPY FROM ... WHERE`` syntax, introduced in
856+
PostgreSQL version 12.
857+
849858
:param float timeout:
850859
Optional timeout value in seconds.
851860
@@ -873,6 +882,9 @@ async def copy_to_table(self, table_name, *, source,
873882
https://www.postgresql.org/docs/current/static/sql-copy.html
874883
875884
.. versionadded:: 0.11.0
885+
886+
.. versionadded:: 0.27.0
887+
Added ``where`` parameter.
876888
"""
877889
tabname = utils._quote_ident(table_name)
878890
if schema_name:
@@ -884,21 +896,22 @@ async def copy_to_table(self, table_name, *, source,
884896
else:
885897
cols = ''
886898

899+
cond = self._format_copy_where(where)
887900
opts = self._format_copy_opts(
888901
format=format, oids=oids, freeze=freeze, delimiter=delimiter,
889902
null=null, header=header, quote=quote, escape=escape,
890903
force_not_null=force_not_null, force_null=force_null,
891904
encoding=encoding
892905
)
893906

894-
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
895-
tab=tabname, cols=cols, opts=opts)
907+
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
908+
tab=tabname, cols=cols, opts=opts, cond=cond)
896909

897910
return await self._copy_in(copy_stmt, source, timeout)
898911

899912
async def copy_records_to_table(self, table_name, *, records,
900913
columns=None, schema_name=None,
901-
timeout=None):
914+
timeout=None, where=None):
902915
"""Copy a list of records to the specified table using binary COPY.
903916
904917
:param str table_name:
@@ -915,6 +928,16 @@ async def copy_records_to_table(self, table_name, *, records,
915928
:param str schema_name:
916929
An optional schema name to qualify the table.
917930
931+
:param str where:
932+
An optional condition used to filter rows when copying.
933+
934+
.. note::
935+
936+
Usage of this parameter requires support for the
937+
``COPY FROM ... WHERE`` syntax, introduced in
938+
PostgreSQL version 12.
939+
940+
918941
:param float timeout:
919942
Optional timeout value in seconds.
920943
@@ -959,6 +982,9 @@ async def copy_records_to_table(self, table_name, *, records,
959982
960983
.. versionchanged:: 0.24.0
961984
The ``records`` argument may be an asynchronous iterable.
985+
986+
.. versionadded:: 0.27.0
987+
Added ``where`` parameter.
962988
"""
963989
tabname = utils._quote_ident(table_name)
964990
if schema_name:
@@ -976,14 +1002,27 @@ async def copy_records_to_table(self, table_name, *, records,
9761002

9771003
intro_ps = await self._prepare(intro_query, use_cache=True)
9781004

1005+
cond = self._format_copy_where(where)
9791006
opts = '(FORMAT binary)'
9801007

981-
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
982-
tab=tabname, cols=cols, opts=opts)
1008+
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
1009+
tab=tabname, cols=cols, opts=opts, cond=cond)
9831010

9841011
return await self._protocol.copy_in(
9851012
copy_stmt, None, None, records, intro_ps._state, timeout)
9861013

1014+
def _format_copy_where(self, where):
1015+
if where and not self._server_caps.sql_copy_from_where:
1016+
raise exceptions.UnsupportedServerFeatureError(
1017+
'the `where` parameter requires PostgreSQL 12 or later')
1018+
1019+
if where:
1020+
where_clause = 'WHERE ' + where
1021+
else:
1022+
where_clause = ''
1023+
1024+
return where_clause
1025+
9871026
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
9881027
delimiter=None, null=None, header=None, quote=None,
9891028
escape=None, force_quote=None, force_not_null=None,
@@ -2308,7 +2347,7 @@ class _ConnectionProxy:
23082347
ServerCapabilities = collections.namedtuple(
23092348
'ServerCapabilities',
23102349
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
2311-
'sql_close_all'])
2350+
'sql_close_all', 'sql_copy_from_where'])
23122351
ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.'
23132352

23142353

@@ -2320,27 +2359,31 @@ def _detect_server_capabilities(server_version, connection_settings):
23202359
plpgsql = False
23212360
sql_reset = True
23222361
sql_close_all = False
2362+
sql_copy_from_where = False
23232363
elif hasattr(connection_settings, 'crdb_version'):
23242364
# CockroachDB detected.
23252365
advisory_locks = False
23262366
notifications = False
23272367
plpgsql = False
23282368
sql_reset = False
23292369
sql_close_all = False
2370+
sql_copy_from_where = False
23302371
elif hasattr(connection_settings, 'crate_version'):
23312372
# CrateDB detected.
23322373
advisory_locks = False
23332374
notifications = False
23342375
plpgsql = False
23352376
sql_reset = False
23362377
sql_close_all = False
2378+
sql_copy_from_where = False
23372379
else:
23382380
# Standard PostgreSQL server assumed.
23392381
advisory_locks = True
23402382
notifications = True
23412383
plpgsql = True
23422384
sql_reset = True
23432385
sql_close_all = True
2386+
sql_copy_from_where = server_version.major >= 12
23442387

23452388
return ServerCapabilities(
23462389
advisory_locks=advisory_locks,

asyncpg/exceptions/_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
1414
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
1515
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
16-
'UnsupportedClientFeatureError')
16+
'UnsupportedClientFeatureError', 'UnsupportedServerFeatureError')
1717

1818

1919
def _is_asyncpg_class(cls):
@@ -228,6 +228,10 @@ class UnsupportedClientFeatureError(InterfaceError):
228228
"""Requested feature is unsupported by asyncpg."""
229229

230230

231+
class UnsupportedServerFeatureError(InterfaceError):
232+
"""Requested feature is unsupported by PostgreSQL server."""
233+
234+
231235
class InterfaceWarning(InterfaceMessage, UserWarning):
232236
"""A warning caused by an improper use of asyncpg API."""
233237

asyncpg/pool.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,8 @@ async def copy_to_table(
736736
force_quote=None,
737737
force_not_null=None,
738738
force_null=None,
739-
encoding=None
739+
encoding=None,
740+
where=None
740741
):
741742
"""Copy data to the specified table.
742743
@@ -765,7 +766,8 @@ async def copy_to_table(
765766
force_quote=force_quote,
766767
force_not_null=force_not_null,
767768
force_null=force_null,
768-
encoding=encoding
769+
encoding=encoding,
770+
where=where
769771
)
770772

771773
async def copy_records_to_table(
@@ -775,7 +777,8 @@ async def copy_records_to_table(
775777
records,
776778
columns=None,
777779
schema_name=None,
778-
timeout=None
780+
timeout=None,
781+
where=None
779782
):
780783
"""Copy a list of records to the specified table using binary COPY.
781784
@@ -792,7 +795,8 @@ async def copy_records_to_table(
792795
records=records,
793796
columns=columns,
794797
schema_name=schema_name,
795-
timeout=timeout
798+
timeout=timeout,
799+
where=where
796800
)
797801

798802
def acquire(self, *, timeout=None):

0 commit comments

Comments
 (0)