@@ -827,7 +827,7 @@ async def copy_to_table(self, table_name, *, source,
827
827
delimiter = None , null = None , header = None ,
828
828
quote = None , escape = None , force_quote = None ,
829
829
force_not_null = None , force_null = None ,
830
- encoding = None ):
830
+ encoding = None , where = None ):
831
831
"""Copy data to the specified table.
832
832
833
833
:param str table_name:
@@ -846,6 +846,15 @@ async def copy_to_table(self, table_name, *, source,
846
846
:param str schema_name:
847
847
An optional schema name to qualify the table.
848
848
849
+ :param str where:
850
+ An optional condition used to filter rows when copying.
851
+
852
+ .. note::
853
+
854
+ Usage of this parameter requires support for the
855
+ ``COPY FROM ... WHERE`` syntax, introduced in
856
+ PostgreSQL version 12.
857
+
849
858
:param float timeout:
850
859
Optional timeout value in seconds.
851
860
@@ -873,6 +882,9 @@ async def copy_to_table(self, table_name, *, source,
873
882
https://www.postgresql.org/docs/current/static/sql-copy.html
874
883
875
884
.. versionadded:: 0.11.0
885
+
886
+ .. versionadded:: 0.27.0
887
+ Added ``where`` parameter.
876
888
"""
877
889
tabname = utils ._quote_ident (table_name )
878
890
if schema_name :
@@ -884,21 +896,22 @@ async def copy_to_table(self, table_name, *, source,
884
896
else :
885
897
cols = ''
886
898
899
+ cond = self ._format_copy_where (where )
887
900
opts = self ._format_copy_opts (
888
901
format = format , oids = oids , freeze = freeze , delimiter = delimiter ,
889
902
null = null , header = header , quote = quote , escape = escape ,
890
903
force_not_null = force_not_null , force_null = force_null ,
891
904
encoding = encoding
892
905
)
893
906
894
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
895
- tab = tabname , cols = cols , opts = opts )
907
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
908
+ tab = tabname , cols = cols , opts = opts , cond = cond )
896
909
897
910
return await self ._copy_in (copy_stmt , source , timeout )
898
911
899
912
async def copy_records_to_table (self , table_name , * , records ,
900
913
columns = None , schema_name = None ,
901
- timeout = None ):
914
+ timeout = None , where = None ):
902
915
"""Copy a list of records to the specified table using binary COPY.
903
916
904
917
:param str table_name:
@@ -915,6 +928,16 @@ async def copy_records_to_table(self, table_name, *, records,
915
928
:param str schema_name:
916
929
An optional schema name to qualify the table.
917
930
931
+ :param str where:
932
+ An optional condition used to filter rows when copying.
933
+
934
+ .. note::
935
+
936
+ Usage of this parameter requires support for the
937
+ ``COPY FROM ... WHERE`` syntax, introduced in
938
+ PostgreSQL version 12.
939
+
940
+
918
941
:param float timeout:
919
942
Optional timeout value in seconds.
920
943
@@ -959,6 +982,9 @@ async def copy_records_to_table(self, table_name, *, records,
959
982
960
983
.. versionchanged:: 0.24.0
961
984
The ``records`` argument may be an asynchronous iterable.
985
+
986
+ .. versionadded:: 0.27.0
987
+ Added ``where`` parameter.
962
988
"""
963
989
tabname = utils ._quote_ident (table_name )
964
990
if schema_name :
@@ -976,14 +1002,27 @@ async def copy_records_to_table(self, table_name, *, records,
976
1002
977
1003
intro_ps = await self ._prepare (intro_query , use_cache = True )
978
1004
1005
+ cond = self ._format_copy_where (where )
979
1006
opts = '(FORMAT binary)'
980
1007
981
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
982
- tab = tabname , cols = cols , opts = opts )
1008
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
1009
+ tab = tabname , cols = cols , opts = opts , cond = cond )
983
1010
984
1011
return await self ._protocol .copy_in (
985
1012
copy_stmt , None , None , records , intro_ps ._state , timeout )
986
1013
1014
+ def _format_copy_where (self , where ):
1015
+ if where and not self ._server_caps .sql_copy_from_where :
1016
+ raise exceptions .UnsupportedServerFeatureError (
1017
+ 'the `where` parameter requires PostgreSQL 12 or later' )
1018
+
1019
+ if where :
1020
+ where_clause = 'WHERE ' + where
1021
+ else :
1022
+ where_clause = ''
1023
+
1024
+ return where_clause
1025
+
987
1026
def _format_copy_opts (self , * , format = None , oids = None , freeze = None ,
988
1027
delimiter = None , null = None , header = None , quote = None ,
989
1028
escape = None , force_quote = None , force_not_null = None ,
@@ -2308,7 +2347,7 @@ class _ConnectionProxy:
2308
2347
ServerCapabilities = collections .namedtuple (
2309
2348
'ServerCapabilities' ,
2310
2349
['advisory_locks' , 'notifications' , 'plpgsql' , 'sql_reset' ,
2311
- 'sql_close_all' ])
2350
+ 'sql_close_all' , 'sql_copy_from_where' ])
2312
2351
ServerCapabilities .__doc__ = 'PostgreSQL server capabilities.'
2313
2352
2314
2353
@@ -2320,27 +2359,31 @@ def _detect_server_capabilities(server_version, connection_settings):
2320
2359
plpgsql = False
2321
2360
sql_reset = True
2322
2361
sql_close_all = False
2362
+ sql_copy_from_where = False
2323
2363
elif hasattr (connection_settings , 'crdb_version' ):
2324
2364
# CockroachDB detected.
2325
2365
advisory_locks = False
2326
2366
notifications = False
2327
2367
plpgsql = False
2328
2368
sql_reset = False
2329
2369
sql_close_all = False
2370
+ sql_copy_from_where = False
2330
2371
elif hasattr (connection_settings , 'crate_version' ):
2331
2372
# CrateDB detected.
2332
2373
advisory_locks = False
2333
2374
notifications = False
2334
2375
plpgsql = False
2335
2376
sql_reset = False
2336
2377
sql_close_all = False
2378
+ sql_copy_from_where = False
2337
2379
else :
2338
2380
# Standard PostgreSQL server assumed.
2339
2381
advisory_locks = True
2340
2382
notifications = True
2341
2383
plpgsql = True
2342
2384
sql_reset = True
2343
2385
sql_close_all = True
2386
+ sql_copy_from_where = server_version .major >= 12
2344
2387
2345
2388
return ServerCapabilities (
2346
2389
advisory_locks = advisory_locks ,
0 commit comments