@@ -832,7 +832,7 @@ async def copy_to_table(self, table_name, *, source,
832
832
delimiter = None , null = None , header = None ,
833
833
quote = None , escape = None , force_quote = None ,
834
834
force_not_null = None , force_null = None ,
835
- encoding = None ):
835
+ encoding = None , where = None ):
836
836
"""Copy data to the specified table.
837
837
838
838
:param str table_name:
@@ -851,6 +851,15 @@ async def copy_to_table(self, table_name, *, source,
851
851
:param str schema_name:
852
852
An optional schema name to qualify the table.
853
853
854
+ :param str where:
855
+ An optional condition used to filter rows when copying.
856
+
857
+ .. note::
858
+
859
+ Usage of this parameter requires support for the
860
+ ``COPY FROM ... WHERE`` syntax, introduced in
861
+ PostgreSQL version 12.
862
+
854
863
:param float timeout:
855
864
Optional timeout value in seconds.
856
865
@@ -878,6 +887,9 @@ async def copy_to_table(self, table_name, *, source,
878
887
https://www.postgresql.org/docs/current/static/sql-copy.html
879
888
880
889
.. versionadded:: 0.11.0
890
+
891
+ .. versionadded:: 0.27.0
892
+ Added ``where`` parameter.
881
893
"""
882
894
tabname = utils ._quote_ident (table_name )
883
895
if schema_name :
@@ -889,21 +901,22 @@ async def copy_to_table(self, table_name, *, source,
889
901
else :
890
902
cols = ''
891
903
904
+ cond = self ._format_copy_where (where )
892
905
opts = self ._format_copy_opts (
893
906
format = format , oids = oids , freeze = freeze , delimiter = delimiter ,
894
907
null = null , header = header , quote = quote , escape = escape ,
895
908
force_not_null = force_not_null , force_null = force_null ,
896
909
encoding = encoding
897
910
)
898
911
899
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
900
- tab = tabname , cols = cols , opts = opts )
912
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
913
+ tab = tabname , cols = cols , opts = opts , cond = cond )
901
914
902
915
return await self ._copy_in (copy_stmt , source , timeout )
903
916
904
917
async def copy_records_to_table (self , table_name , * , records ,
905
918
columns = None , schema_name = None ,
906
- timeout = None ):
919
+ timeout = None , where = None ):
907
920
"""Copy a list of records to the specified table using binary COPY.
908
921
909
922
:param str table_name:
@@ -920,6 +933,16 @@ async def copy_records_to_table(self, table_name, *, records,
920
933
:param str schema_name:
921
934
An optional schema name to qualify the table.
922
935
936
+ :param str where:
937
+ An optional condition used to filter rows when copying.
938
+
939
+ .. note::
940
+
941
+ Usage of this parameter requires support for the
942
+ ``COPY FROM ... WHERE`` syntax, introduced in
943
+ PostgreSQL version 12.
944
+
945
+
923
946
:param float timeout:
924
947
Optional timeout value in seconds.
925
948
@@ -964,6 +987,9 @@ async def copy_records_to_table(self, table_name, *, records,
964
987
965
988
.. versionchanged:: 0.24.0
966
989
The ``records`` argument may be an asynchronous iterable.
990
+
991
+ .. versionadded:: 0.27.0
992
+ Added ``where`` parameter.
967
993
"""
968
994
tabname = utils ._quote_ident (table_name )
969
995
if schema_name :
@@ -981,14 +1007,27 @@ async def copy_records_to_table(self, table_name, *, records,
981
1007
982
1008
intro_ps = await self ._prepare (intro_query , use_cache = True )
983
1009
1010
+ cond = self ._format_copy_where (where )
984
1011
opts = '(FORMAT binary)'
985
1012
986
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
987
- tab = tabname , cols = cols , opts = opts )
1013
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
1014
+ tab = tabname , cols = cols , opts = opts , cond = cond )
988
1015
989
1016
return await self ._protocol .copy_in (
990
1017
copy_stmt , None , None , records , intro_ps ._state , timeout )
991
1018
1019
+ def _format_copy_where (self , where ):
1020
+ if where and not self ._server_caps .sql_copy_from_where :
1021
+ raise exceptions .UnsupportedServerFeatureError (
1022
+ 'the `where` parameter requires PostgreSQL 12 or later' )
1023
+
1024
+ if where :
1025
+ where_clause = 'WHERE ' + where
1026
+ else :
1027
+ where_clause = ''
1028
+
1029
+ return where_clause
1030
+
992
1031
def _format_copy_opts (self , * , format = None , oids = None , freeze = None ,
993
1032
delimiter = None , null = None , header = None , quote = None ,
994
1033
escape = None , force_quote = None , force_not_null = None ,
@@ -2370,7 +2409,7 @@ class _ConnectionProxy:
2370
2409
ServerCapabilities = collections .namedtuple (
2371
2410
'ServerCapabilities' ,
2372
2411
['advisory_locks' , 'notifications' , 'plpgsql' , 'sql_reset' ,
2373
- 'sql_close_all' ])
2412
+ 'sql_close_all' , 'sql_copy_from_where' ])
2374
2413
ServerCapabilities .__doc__ = 'PostgreSQL server capabilities.'
2375
2414
2376
2415
@@ -2382,27 +2421,31 @@ def _detect_server_capabilities(server_version, connection_settings):
2382
2421
plpgsql = False
2383
2422
sql_reset = True
2384
2423
sql_close_all = False
2424
+ sql_copy_from_where = False
2385
2425
elif hasattr (connection_settings , 'crdb_version' ):
2386
2426
# CockroachDB detected.
2387
2427
advisory_locks = False
2388
2428
notifications = False
2389
2429
plpgsql = False
2390
2430
sql_reset = False
2391
2431
sql_close_all = False
2432
+ sql_copy_from_where = False
2392
2433
elif hasattr (connection_settings , 'crate_version' ):
2393
2434
# CrateDB detected.
2394
2435
advisory_locks = False
2395
2436
notifications = False
2396
2437
plpgsql = False
2397
2438
sql_reset = False
2398
2439
sql_close_all = False
2440
+ sql_copy_from_where = False
2399
2441
else :
2400
2442
# Standard PostgreSQL server assumed.
2401
2443
advisory_locks = True
2402
2444
notifications = True
2403
2445
plpgsql = True
2404
2446
sql_reset = True
2405
2447
sql_close_all = True
2448
+ sql_copy_from_where = server_version .major >= 12
2406
2449
2407
2450
return ServerCapabilities (
2408
2451
advisory_locks = advisory_locks ,
0 commit comments