Skip to content

Commit 0edbab9

Browse files
authored
fix(postgres): use .transaction method instead of managing our own
1 parent 44151f0 commit 0edbab9

File tree

5 files changed

+63
-61
lines changed

5 files changed

+63
-61
lines changed

.github/workflows/ibis-backends.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ jobs:
141141
- default-libmysqlclient-dev
142142
- name: postgres
143143
title: PostgreSQL
144+
serial: true
144145
extras:
145146
- --extra postgres
146147
- --extra geospatial
@@ -150,6 +151,7 @@ jobs:
150151
- libgeos-dev
151152
- name: postgres
152153
title: PostgreSQL + Torch
154+
serial: true
153155
extras:
154156
- --extra postgres
155157
- --extra geospatial
@@ -319,6 +321,7 @@ jobs:
319321
backend:
320322
name: postgres
321323
title: PostgreSQL
324+
serial: true
322325
extras:
323326
- --extra postgres
324327
- --extra geospatial
@@ -339,6 +342,7 @@ jobs:
339342
backend:
340343
name: postgres
341344
title: PostgreSQL + Torch
345+
serial: true
342346
extras:
343347
- --extra postgres
344348
- --extra geospatial

ibis/backends/postgres/__init__.py

Lines changed: 45 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import pandas as pd
3535
import polars as pl
3636
import pyarrow as pa
37+
from typing_extensions import Self
3738

3839

3940
class NatDumper(psycopg.adapt.Dumper):
@@ -131,23 +132,17 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
131132
name, schema=schema, columns=True, placeholder="%s"
132133
)
133134

134-
with self.begin() as cur:
135-
cur.execute(create_stmt_sql)
136-
cur.executemany(sql, data)
135+
with self.begin() as cursor:
136+
cursor.execute(create_stmt_sql)
137+
cursor.executemany(sql, data)
137138

138139
@contextlib.contextmanager
139-
def begin(self):
140-
con = self.con
141-
cursor = con.cursor()
142-
try:
140+
def begin(self, *, name: str = "", withhold: bool = False):
141+
with (
142+
(con := self.con).transaction(),
143+
con.cursor(name=name, withhold=withhold) as cursor,
144+
):
143145
yield cursor
144-
except Exception:
145-
con.rollback()
146-
raise
147-
else:
148-
con.commit()
149-
finally:
150-
cursor.close()
151146

152147
def _fetch_from_cursor(
153148
self, cursor: psycopg.Cursor, schema: sch.Schema
@@ -190,6 +185,7 @@ def do_connect(
190185
port: int = 5432,
191186
database: str | None = None,
192187
schema: str | None = None,
188+
autocommit: bool = True,
193189
**kwargs: Any,
194190
) -> None:
195191
"""Create an Ibis client connected to PostgreSQL database.
@@ -208,6 +204,8 @@ def do_connect(
208204
Database to connect to
209205
schema
210206
PostgreSQL schema to use. If `None`, use the default `search_path`.
207+
autocommit
208+
Whether or not to autocommit
211209
kwargs
212210
Additional keyword arguments to pass to the backend client connection.
213211
@@ -252,6 +250,7 @@ def do_connect(
252250
password=password,
253251
dbname=database,
254252
options=(f"-csearch_path={schema}" * (schema is not None)) or None,
253+
autocommit=autocommit,
255254
**kwargs,
256255
)
257256

@@ -276,8 +275,23 @@ def from_connection(cls, con: psycopg.Connection, /) -> Backend:
276275
return new_backend
277276

278277
def _post_connect(self) -> None:
279-
with self.begin() as cur:
280-
cur.execute("SET TIMEZONE = UTC")
278+
import psycopg.types
279+
import psycopg.types.hstore
280+
281+
try:
282+
# try to load hstore
283+
with self.begin() as cursor:
284+
cursor.execute("CREATE EXTENSION IF NOT EXISTS hstore")
285+
psycopg.types.hstore.register_hstore(
286+
psycopg.types.TypeInfo.fetch(self.con, "hstore"), self.con
287+
)
288+
except psycopg.Error as e:
289+
warnings.warn(f"Failed to load hstore extension: {e}")
290+
except TypeError:
291+
pass
292+
293+
with self.begin() as cursor:
294+
cursor.execute("SET TIMEZONE = UTC")
281295

282296
@property
283297
def _session_temp_db(self) -> str | None:
@@ -706,23 +720,12 @@ def drop_table(
706720
pass
707721

708722
@contextlib.contextmanager
709-
def _safe_raw_sql(self, *args, **kwargs):
710-
with contextlib.closing(self.raw_sql(*args, **kwargs)) as result:
711-
yield result
723+
def _safe_raw_sql(self, query: str | sg.Expression, **kwargs: Any):
724+
with contextlib.suppress(AttributeError):
725+
query = query.sql(dialect=self.dialect)
712726

713-
def _register_hstore(self, cursor: psycopg.Cursor) -> None:
714-
import psycopg.types
715-
import psycopg.types.hstore
716-
717-
try:
718-
# try to load hstore
719-
psycopg.types.hstore.register_hstore(
720-
psycopg.types.TypeInfo.fetch(self.con, "hstore"), cursor
721-
)
722-
except psycopg.Error as e:
723-
warnings.warn(f"Failed to load hstore extension: {e}")
724-
except TypeError:
725-
pass
727+
with self.begin() as cursor:
728+
yield cursor.execute(query, **kwargs)
726729

727730
def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
728731
with contextlib.suppress(AttributeError):
@@ -731,16 +734,12 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
731734
con = self.con
732735
cursor = con.cursor()
733736

734-
self._register_hstore(cursor)
735-
736737
try:
737738
cursor.execute(query, **kwargs)
738739
except Exception:
739-
con.rollback()
740740
cursor.close()
741741
raise
742742
else:
743-
con.commit()
744743
return cursor
745744

746745
@util.experimental
@@ -757,33 +756,22 @@ def to_pyarrow_batches(
757756
import pandas as pd
758757
import pyarrow as pa
759758

760-
def _batches(*, schema: pa.Schema, con: psycopg.Connection, query: str):
759+
def _batches(self: Self, *, schema: pa.Schema, query: str):
761760
columns = schema.names
762761
# server-side cursors need to be uniquely named
763-
with con.cursor(name=util.gen_name("postgres_cursor")) as cursor:
764-
self._register_hstore(cursor)
765-
766-
try:
767-
cur = cursor.execute(query)
768-
except Exception:
769-
con.rollback()
770-
raise
771-
else:
772-
try:
773-
while batch := cur.fetchmany(chunk_size):
774-
yield pa.RecordBatch.from_pandas(
775-
pd.DataFrame(batch, columns=columns), schema=schema
776-
)
777-
except Exception:
778-
con.rollback()
779-
raise
780-
else:
781-
con.commit()
762+
with self.begin(
763+
name=util.gen_name("postgres_cursor"), withhold=True
764+
) as cursor:
765+
cursor.execute(query)
766+
while batch := cursor.fetchmany(chunk_size):
767+
yield pa.RecordBatch.from_pandas(
768+
pd.DataFrame(batch, columns=columns), schema=schema
769+
)
782770

783771
self._run_pre_execute_hooks(expr)
784772

785773
schema = expr.as_table().schema().to_pyarrow()
786774
query = self.compile(expr, limit=limit, params=params)
787775
return pa.RecordBatchReader.from_batches(
788-
schema, _batches(schema=schema, con=self.con, query=query)
776+
schema, _batches(self, schema=schema, query=query)
789777
)

ibis/backends/postgres/tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def connect(*, tmpdir, worker_id, **kw):
6464
user=PG_USER,
6565
password=PG_PASS,
6666
database=IBIS_TEST_POSTGRES_DB,
67+
application_name="Ibis test suite",
6768
**kw,
6869
)
6970

ibis/backends/postgres/tests/test_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,10 @@ def test_kwargs_passthrough_in_connect():
255255
con = ibis.connect(
256256
"postgresql://postgres:postgres@localhost:5432/ibis_testing?sslmode=allow"
257257
)
258-
assert con.current_catalog == "ibis_testing"
258+
try:
259+
assert con.current_catalog == "ibis_testing"
260+
finally:
261+
con.disconnect()
259262

260263

261264
def test_port():

ibis/backends/tests/test_client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -819,8 +819,10 @@ def test_unsigned_integer_type(con, temp_table):
819819
)
820820
def test_connect_url(url):
821821
con = ibis.connect(url)
822-
one = ibis.literal(1)
823-
assert con.execute(one) == 1
822+
try:
823+
assert con.execute(ibis.literal(1)) == 1
824+
finally:
825+
con.disconnect()
824826

825827

826828
@pytest.mark.parametrize(
@@ -1295,7 +1297,11 @@ def test_set_backend_url(url, monkeypatch):
12951297
monkeypatch.setattr(ibis.options, "default_backend", None)
12961298
name = url.split("://")[0]
12971299
ibis.set_backend(url)
1298-
assert ibis.get_backend().name == name
1300+
con = ibis.get_backend()
1301+
try:
1302+
assert con.name == name
1303+
finally:
1304+
con.disconnect()
12991305

13001306

13011307
@pytest.mark.notyet(

0 commit comments

Comments
 (0)