Skip to content

Commit 6f9d141

Browse files
committed
fix(postgres): clean up possible transaction hangs
1 parent bd0fa0b commit 6f9d141

File tree

3 files changed

+97
-71
lines changed

3 files changed

+97
-71
lines changed

ibis/backends/postgres/__init__.py

Lines changed: 81 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -134,34 +134,58 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
134134

135135
con = self.con
136136
with con.cursor() as cursor, con.transaction():
137-
cursor.execute(create_stmt_sql)
138-
cursor.executemany(sql, data)
137+
cursor.execute(create_stmt_sql).executemany(sql, data)
139138

140139
@contextlib.contextmanager
141140
def begin(self):
142141
with (con := self.con).cursor() as cursor, con.transaction():
143142
yield cursor
144143

145-
def _fetch_from_cursor(
146-
self, cursor: psycopg.Cursor, schema: sch.Schema
147-
) -> pd.DataFrame:
144+
def execute(
145+
self,
146+
expr: ir.Expr,
147+
/,
148+
*,
149+
params: Mapping[ir.Scalar, Any] | None = None,
150+
limit: int | str | None = None,
151+
**kwargs: Any,
152+
) -> pd.DataFrame | pd.Series | Any:
153+
"""Execute an Ibis expression and return a pandas `DataFrame`, `Series`, or scalar.
154+
155+
Parameters
156+
----------
157+
expr
158+
Ibis expression to execute.
159+
params
160+
Mapping of scalar parameter expressions to value.
161+
limit
162+
An integer to effect a specific row limit. A value of `None` means
163+
no limit. The default is in `ibis/config.py`.
164+
kwargs
165+
Keyword arguments
166+
167+
Returns
168+
-------
169+
DataFrame | Series | scalar
170+
The result of the expression execution.
171+
"""
148172
import pandas as pd
149173

150174
from ibis.backends.postgres.converter import PostgresPandasData
151175

152-
try:
153-
df = pd.DataFrame.from_records(
154-
cursor.fetchall(), columns=schema.names, coerce_float=True
155-
)
156-
except Exception:
157-
# clean up the cursor if we fail to create the DataFrame
158-
#
159-
# in the sqlite case failing to close the cursor results in
160-
# artificially locked tables
161-
cursor.close()
162-
raise
176+
self._run_pre_execute_hooks(expr)
177+
178+
table = expr.as_table()
179+
sql = self.compile(table, params=params, limit=limit, **kwargs)
180+
181+
con = self.con
182+
with con.cursor() as cur, con.transaction():
183+
rows = cur.execute(sql).fetchall()
184+
185+
schema = table.schema()
186+
df = pd.DataFrame.from_records(rows, columns=schema.names, coerce_float=True)
163187
df = PostgresPandasData.convert_table(df, schema)
164-
return df
188+
return expr.__pandas_result__(df)
165189

166190
@property
167191
def version(self):
@@ -352,43 +376,34 @@ def list_tables(
352376
catalog = catalog.sql(dialect=self.name)
353377
conditions.append(C.table_catalog.eq(sge.convert(catalog)))
354378

355-
sql = (
356-
sg.select("table_name")
379+
sg_expr = (
380+
sg.select(C.table_name)
357381
.from_(sg.table("tables", db="information_schema"))
358382
.distinct()
359383
.where(*conditions)
360-
.sql(self.dialect)
361384
)
362385

363-
con = self.con
364-
with con.cursor() as cursor, con.transaction():
365-
out = cursor.execute(sql).fetchall()
366-
367-
# Include temporary tables only if no database has been explicitly specified
368-
# to avoid temp tables showing up in all calls to `list_tables`
386+
# Include temporary tables only if no database has been explicitly
387+
# specified to avoid temp tables showing up in all calls to
388+
# `list_tables`
369389
if db == "public":
370-
out += self._fetch_temp_tables()
371-
372-
return self._filter_with_like(map(itemgetter(0), out), like)
373-
374-
def _fetch_temp_tables(self):
375-
# postgres temporary tables are stored in a separate schema
376-
# so we need to independently grab them and return them along with
377-
# the existing results
378-
379-
sql = (
380-
sg.select("table_name")
381-
.from_(sg.table("tables", db="information_schema"))
382-
.distinct()
383-
.where(C.table_type.eq(sge.convert("LOCAL TEMPORARY")))
384-
.sql(self.dialect)
385-
)
390+
# postgres temporary tables are stored in a separate schema so we need
391+
# to independently grab them and return them along with the existing
392+
# results
393+
sg_expr = sg_expr.union(
394+
sg.select(C.table_name)
395+
.from_(sg.table("tables", db="information_schema"))
396+
.distinct()
397+
.where(C.table_type.eq(sge.convert("LOCAL TEMPORARY"))),
398+
distinct=False,
399+
)
386400

401+
sql = sg_expr.sql(self.dialect)
387402
con = self.con
388403
with con.cursor() as cursor, con.transaction():
389404
out = cursor.execute(sql).fetchall()
390405

391-
return out
406+
return self._filter_with_like(map(itemgetter(0), out), like)
392407

393408
def list_catalogs(self, *, like: str | None = None) -> list[str]:
394409
# http://dba.stackexchange.com/a/1304/58517
@@ -400,9 +415,9 @@ def list_catalogs(self, *, like: str | None = None) -> list[str]:
400415
)
401416
con = self.con
402417
with con.cursor() as cursor, con.transaction():
403-
catalogs = list(map(itemgetter(0), cursor.execute(cats)))
418+
catalogs = cursor.execute(cats).fetchall()
404419

405-
return self._filter_with_like(catalogs, like)
420+
return self._filter_with_like(map(itemgetter(0), catalogs), like)
406421

407422
def list_databases(
408423
self, *, like: str | None = None, catalog: str | None = None
@@ -414,24 +429,24 @@ def list_databases(
414429
)
415430
con = self.con
416431
with con.cursor() as cursor, con.transaction():
417-
databases = list(map(itemgetter(0), cursor.execute(dbs)))
432+
databases = cursor.execute(dbs).fetchall()
418433

419-
return self._filter_with_like(databases, like)
434+
return self._filter_with_like(map(itemgetter(0), databases), like)
420435

421436
@property
422437
def current_catalog(self) -> str:
423438
sql = sg.select(sg.func("current_database")).sql(self.dialect)
424439
con = self.con
425440
with con.cursor() as cursor, con.transaction():
426-
(db,) = cursor.execute(sql).fetchone()
441+
[(db,)] = cursor.execute(sql).fetchall()
427442
return db
428443

429444
@property
430445
def current_database(self) -> str:
431446
sql = sg.select(sg.func("current_schema")).sql(self.dialect)
432447
con = self.con
433448
with con.cursor() as cursor, con.transaction():
434-
(schema,) = cursor.execute(sql).fetchone()
449+
[(schema,)] = cursor.execute(sql).fetchall()
435450
return schema
436451

437452
def function(self, name: str, *, database: str | None = None) -> Callable:
@@ -698,20 +713,20 @@ def create_table(
698713
this_no_catalog = sg.table(name, quoted=quoted)
699714

700715
con = self.con
701-
with con.cursor() as cursor, con.transaction():
702-
cursor.execute(create_stmt)
716+
stmts = [create_stmt]
703717

704-
if query is not None:
705-
insert_stmt = sge.Insert(this=table_expr, expression=query).sql(dialect)
706-
cursor.execute(insert_stmt)
718+
if query is not None:
719+
stmts.append(sge.Insert(this=table_expr, expression=query).sql(dialect))
707720

708-
if overwrite:
709-
cursor.execute(
710-
sge.Drop(kind="TABLE", this=this, exists=True).sql(dialect)
711-
)
712-
cursor.execute(
713-
f"ALTER TABLE IF EXISTS {table_expr.sql(dialect)} RENAME TO {this_no_catalog.sql(dialect)}"
714-
)
721+
if overwrite:
722+
stmts.append(sge.Drop(kind="TABLE", this=this, exists=True).sql(dialect))
723+
stmts.append(
724+
f"ALTER TABLE IF EXISTS {table_expr.sql(dialect)} RENAME TO {this_no_catalog.sql(dialect)}"
725+
)
726+
727+
with con.cursor() as cursor, con.transaction():
728+
for stmt in stmts:
729+
cursor.execute(stmt)
715730

716731
if schema is None:
717732
return self.table(name, database=database)
@@ -743,7 +758,8 @@ def _safe_raw_sql(self, query: str | sg.Expression, **kwargs: Any):
743758
with contextlib.suppress(AttributeError):
744759
query = query.sql(dialect=self.dialect)
745760

746-
with (con := self.con).cursor() as cursor, con.transaction():
761+
con = self.con
762+
with con.cursor() as cursor, con.transaction():
747763
yield cursor.execute(query, **kwargs)
748764

749765
def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
@@ -757,6 +773,7 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
757773
cursor.execute(query, **kwargs)
758774
except Exception:
759775
cursor.close()
776+
con.rollback()
760777
raise
761778
else:
762779
return cursor
@@ -783,8 +800,8 @@ def _batches(self: Self, *, schema: pa.Schema, query: str):
783800
con.cursor(name=util.gen_name("postgres_cursor")) as cursor,
784801
con.transaction(),
785802
):
786-
cursor.execute(query)
787-
while batch := cursor.fetchmany(chunk_size):
803+
cur = cursor.execute(query)
804+
while batch := cur.fetchmany(chunk_size):
788805
yield pa.RecordBatch.from_pandas(
789806
pd.DataFrame(batch, columns=columns), schema=schema
790807
)

ibis/backends/postgres/tests/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def connect(*, tmpdir, worker_id, **kw):
6464
user=PG_USER,
6565
password=PG_PASS,
6666
database=IBIS_TEST_POSTGRES_DB,
67-
application_name="Ibis test suite",
6867
**kw,
6968
)
7069

ibis/backends/postgres/tests/test_client.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from ibis.backends.tests.errors import PsycoPgOperationalError
3434
from ibis.util import gen_name
3535

36-
pytest.importorskip("psycopg")
36+
psycopg = pytest.importorskip("psycopg")
3737

3838
POSTGRES_TEST_DB = os.environ.get("IBIS_TEST_POSTGRES_DATABASE", "ibis_testing")
3939
IBIS_POSTGRES_HOST = os.environ.get("IBIS_TEST_POSTGRES_HOST", "localhost")
@@ -59,6 +59,17 @@ def test_literal_execute(con):
5959
assert result == "1234"
6060

6161

62+
def test_raw_sql(con):
63+
with con.raw_sql("SELECT 1 AS foo") as cur:
64+
assert cur.fetchall() == [(1,)]
65+
con.con.commit()
66+
with (
67+
pytest.raises(psycopg.errors.UndefinedTable),
68+
con.raw_sql("SELECT foo FROM bar"),
69+
):
70+
pass
71+
72+
6273
def test_simple_aggregate_execute(alltypes):
6374
d = alltypes.double_col.sum()
6475
v = d.execute()
@@ -288,10 +299,9 @@ def test_pgvector_type_load(con, vector_size):
288299
result = ["[1,2,3]", "[4,5,6]"]
289300
assert t.to_pyarrow().column("embedding").to_pylist() == result
290301

291-
query = f"""
292-
DROP TABLE IF EXISTS itemsvrandom;
293-
CREATE TABLE itemsvrandom (id bigserial PRIMARY KEY, embedding vector({vector_size}));
294-
"""
302+
query = f"""\
303+
DROP TABLE IF EXISTS itemsvrandom;
304+
CREATE TABLE itemsvrandom (id bigserial PRIMARY KEY, embedding vector({vector_size}))"""
295305

296306
with con._safe_raw_sql(query):
297307
pass
@@ -427,7 +437,7 @@ def test_create_geospatial_table_with_srid(con):
427437
)
428438

429439

430-
@pytest.fixture(scope="module")
440+
@pytest.fixture
431441
def enum_table(con):
432442
name = gen_name("enum_table")
433443
with con._safe_raw_sql("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')") as cur:

0 commit comments

Comments
 (0)