Skip to content

Commit 770bbe5

Browse files
redgoldlaceelprans
authored andcommitted
Add support for WHERE clause in copy_to methods
1 parent 89d5bd0 commit 770bbe5

File tree

3 files changed

+64
-12
lines changed

3 files changed

+64
-12
lines changed

asyncpg/connection.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ async def copy_to_table(self, table_name, *, source,
832832
delimiter=None, null=None, header=None,
833833
quote=None, escape=None, force_quote=None,
834834
force_not_null=None, force_null=None,
835-
encoding=None):
835+
encoding=None, where=None):
836836
"""Copy data to the specified table.
837837
838838
:param str table_name:
@@ -851,6 +851,15 @@ async def copy_to_table(self, table_name, *, source,
851851
:param str schema_name:
852852
An optional schema name to qualify the table.
853853
854+
:param str where:
855+
An optional condition used to filter rows when copying.
856+
857+
.. note::
858+
859+
Usage of this parameter requires support for the
860+
``COPY FROM ... WHERE`` syntax, introduced in
861+
PostgreSQL version 12.
862+
854863
:param float timeout:
855864
Optional timeout value in seconds.
856865
@@ -878,6 +887,9 @@ async def copy_to_table(self, table_name, *, source,
878887
https://www.postgresql.org/docs/current/static/sql-copy.html
879888
880889
.. versionadded:: 0.11.0
890+
891+
.. versionadded:: 0.27.0
892+
Added ``where`` parameter.
881893
"""
882894
tabname = utils._quote_ident(table_name)
883895
if schema_name:
@@ -889,21 +901,22 @@ async def copy_to_table(self, table_name, *, source,
889901
else:
890902
cols = ''
891903

904+
cond = self._format_copy_where(where)
892905
opts = self._format_copy_opts(
893906
format=format, oids=oids, freeze=freeze, delimiter=delimiter,
894907
null=null, header=header, quote=quote, escape=escape,
895908
force_not_null=force_not_null, force_null=force_null,
896909
encoding=encoding
897910
)
898911

899-
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
900-
tab=tabname, cols=cols, opts=opts)
912+
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
913+
tab=tabname, cols=cols, opts=opts, cond=cond)
901914

902915
return await self._copy_in(copy_stmt, source, timeout)
903916

904917
async def copy_records_to_table(self, table_name, *, records,
905918
columns=None, schema_name=None,
906-
timeout=None):
919+
timeout=None, where=None):
907920
"""Copy a list of records to the specified table using binary COPY.
908921
909922
:param str table_name:
@@ -920,6 +933,16 @@ async def copy_records_to_table(self, table_name, *, records,
920933
:param str schema_name:
921934
An optional schema name to qualify the table.
922935
936+
:param str where:
937+
An optional condition used to filter rows when copying.
938+
939+
.. note::
940+
941+
Usage of this parameter requires support for the
942+
``COPY FROM ... WHERE`` syntax, introduced in
943+
PostgreSQL version 12.
944+
945+
923946
:param float timeout:
924947
Optional timeout value in seconds.
925948
@@ -964,6 +987,9 @@ async def copy_records_to_table(self, table_name, *, records,
964987
965988
.. versionchanged:: 0.24.0
966989
The ``records`` argument may be an asynchronous iterable.
990+
991+
.. versionadded:: 0.27.0
992+
Added ``where`` parameter.
967993
"""
968994
tabname = utils._quote_ident(table_name)
969995
if schema_name:
@@ -981,14 +1007,27 @@ async def copy_records_to_table(self, table_name, *, records,
9811007

9821008
intro_ps = await self._prepare(intro_query, use_cache=True)
9831009

1010+
cond = self._format_copy_where(where)
9841011
opts = '(FORMAT binary)'
9851012

986-
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
987-
tab=tabname, cols=cols, opts=opts)
1013+
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
1014+
tab=tabname, cols=cols, opts=opts, cond=cond)
9881015

9891016
return await self._protocol.copy_in(
9901017
copy_stmt, None, None, records, intro_ps._state, timeout)
9911018

1019+
def _format_copy_where(self, where):
1020+
if where and not self._server_caps.sql_copy_from_where:
1021+
raise exceptions.UnsupportedServerFeatureError(
1022+
'the `where` parameter requires PostgreSQL 12 or later')
1023+
1024+
if where:
1025+
where_clause = 'WHERE ' + where
1026+
else:
1027+
where_clause = ''
1028+
1029+
return where_clause
1030+
9921031
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
9931032
delimiter=None, null=None, header=None, quote=None,
9941033
escape=None, force_quote=None, force_not_null=None,
@@ -2370,7 +2409,7 @@ class _ConnectionProxy:
23702409
ServerCapabilities = collections.namedtuple(
23712410
'ServerCapabilities',
23722411
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
2373-
'sql_close_all'])
2412+
'sql_close_all', 'sql_copy_from_where'])
23742413
ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.'
23752414

23762415

@@ -2382,27 +2421,31 @@ def _detect_server_capabilities(server_version, connection_settings):
23822421
plpgsql = False
23832422
sql_reset = True
23842423
sql_close_all = False
2424+
sql_copy_from_where = False
23852425
elif hasattr(connection_settings, 'crdb_version'):
23862426
# CockroachDB detected.
23872427
advisory_locks = False
23882428
notifications = False
23892429
plpgsql = False
23902430
sql_reset = False
23912431
sql_close_all = False
2432+
sql_copy_from_where = False
23922433
elif hasattr(connection_settings, 'crate_version'):
23932434
# CrateDB detected.
23942435
advisory_locks = False
23952436
notifications = False
23962437
plpgsql = False
23972438
sql_reset = False
23982439
sql_close_all = False
2440+
sql_copy_from_where = False
23992441
else:
24002442
# Standard PostgreSQL server assumed.
24012443
advisory_locks = True
24022444
notifications = True
24032445
plpgsql = True
24042446
sql_reset = True
24052447
sql_close_all = True
2448+
sql_copy_from_where = server_version.major >= 12
24062449

24072450
return ServerCapabilities(
24082451
advisory_locks=advisory_locks,

asyncpg/exceptions/_base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212

1313
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
1414
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
15+
'ClientConfigurationError',
1516
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
1617
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
17-
'ClientConfigurationError')
18+
'UnsupportedServerFeatureError')
1819

1920

2021
def _is_asyncpg_class(cls):
@@ -233,6 +234,10 @@ class UnsupportedClientFeatureError(InterfaceError):
233234
"""Requested feature is unsupported by asyncpg."""
234235

235236

237+
class UnsupportedServerFeatureError(InterfaceError):
238+
"""Requested feature is unsupported by PostgreSQL server."""
239+
240+
236241
class InterfaceWarning(InterfaceMessage, UserWarning):
237242
"""A warning caused by an improper use of asyncpg API."""
238243

asyncpg/pool.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,8 @@ async def copy_to_table(
739739
force_quote=None,
740740
force_not_null=None,
741741
force_null=None,
742-
encoding=None
742+
encoding=None,
743+
where=None
743744
):
744745
"""Copy data to the specified table.
745746
@@ -768,7 +769,8 @@ async def copy_to_table(
768769
force_quote=force_quote,
769770
force_not_null=force_not_null,
770771
force_null=force_null,
771-
encoding=encoding
772+
encoding=encoding,
773+
where=where
772774
)
773775

774776
async def copy_records_to_table(
@@ -778,7 +780,8 @@ async def copy_records_to_table(
778780
records,
779781
columns=None,
780782
schema_name=None,
781-
timeout=None
783+
timeout=None,
784+
where=None
782785
):
783786
"""Copy a list of records to the specified table using binary COPY.
784787
@@ -795,7 +798,8 @@ async def copy_records_to_table(
795798
records=records,
796799
columns=columns,
797800
schema_name=schema_name,
798-
timeout=timeout
801+
timeout=timeout,
802+
where=where
799803
)
800804

801805
def acquire(self, *, timeout=None):

0 commit comments

Comments
 (0)