34
34
import pandas as pd
35
35
import polars as pl
36
36
import pyarrow as pa
37
+ from typing_extensions import Self
37
38
38
39
39
40
class NatDumper (psycopg .adapt .Dumper ):
@@ -131,23 +132,17 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
131
132
name , schema = schema , columns = True , placeholder = "%s"
132
133
)
133
134
134
- with self .begin () as cur :
135
- cur .execute (create_stmt_sql )
136
- cur .executemany (sql , data )
135
+ with self .begin () as cursor :
136
+ cursor .execute (create_stmt_sql )
137
+ cursor .executemany (sql , data )
137
138
138
139
@contextlib.contextmanager
def begin(self, *, name: str = "", withhold: bool = False):
    """Yield a cursor scoped to a transaction.

    The transaction commits when the block exits normally and rolls
    back if the block raises.

    Parameters
    ----------
    name
        Server-side cursor name; an empty string (the default) yields
        an ordinary client-side cursor.
    withhold
        Whether a named cursor remains usable after the transaction
        closes (only meaningful for named cursors).
    """
    con = self.con
    # Enter the transaction first, then open the cursor inside it,
    # matching psycopg's recommended nesting.
    with con.transaction():
        with con.cursor(name=name, withhold=withhold) as cursor:
            yield cursor
151
146
152
147
def _fetch_from_cursor (
153
148
self , cursor : psycopg .Cursor , schema : sch .Schema
@@ -190,6 +185,7 @@ def do_connect(
190
185
port : int = 5432 ,
191
186
database : str | None = None ,
192
187
schema : str | None = None ,
188
+ autocommit : bool = True ,
193
189
** kwargs : Any ,
194
190
) -> None :
195
191
"""Create an Ibis client connected to PostgreSQL database.
@@ -208,6 +204,8 @@ def do_connect(
208
204
Database to connect to
209
205
schema
210
206
PostgreSQL schema to use. If `None`, use the default `search_path`.
207
+ autocommit
208
+ Whether or not to autocommit
211
209
kwargs
212
210
Additional keyword arguments to pass to the backend client connection.
213
211
@@ -252,6 +250,7 @@ def do_connect(
252
250
password = password ,
253
251
dbname = database ,
254
252
options = (f"-csearch_path={ schema } " * (schema is not None )) or None ,
253
+ autocommit = autocommit ,
255
254
** kwargs ,
256
255
)
257
256
@@ -276,8 +275,23 @@ def from_connection(cls, con: psycopg.Connection, /) -> Backend:
276
275
return new_backend
277
276
278
277
def _post_connect(self) -> None:
    """Configure the freshly opened connection: hstore support and UTC."""
    import psycopg.types
    import psycopg.types.hstore

    # Best-effort hstore setup: failing to create or locate the
    # extension must not prevent the connection from being used.
    try:
        with self.begin() as cursor:
            cursor.execute("CREATE EXTENSION IF NOT EXISTS hstore")
        hstore_info = psycopg.types.TypeInfo.fetch(self.con, "hstore")
        # When the type is absent, `fetch` returns None and
        # `register_hstore` raises TypeError, which we ignore below.
        psycopg.types.hstore.register_hstore(hstore_info, self.con)
    except psycopg.Error as e:
        warnings.warn(f"Failed to load hstore extension: {e}")
    except TypeError:
        pass

    with self.begin() as cursor:
        cursor.execute("SET TIMEZONE = UTC")
281
295
282
296
@property
283
297
def _session_temp_db (self ) -> str | None :
@@ -706,23 +720,12 @@ def drop_table(
706
720
pass
707
721
708
722
@contextlib.contextmanager
def _safe_raw_sql(self, query: str | sg.Expression, **kwargs: Any):
    """Execute `query` inside a transaction and yield the live cursor.

    Parameters
    ----------
    query
        Either a raw SQL string or a sqlglot expression; the latter is
        compiled to this backend's dialect first.
    kwargs
        Extra arguments forwarded to `cursor.execute`.
    """
    compiled = query
    # Plain strings have no `.sql` method, so they pass through untouched.
    with contextlib.suppress(AttributeError):
        compiled = compiled.sql(dialect=self.dialect)

    with self.begin() as cursor:
        yield cursor.execute(compiled, **kwargs)
726
729
727
730
def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
    """Execute `query` and return the open cursor.

    The caller owns the returned cursor and must close it; it is
    closed here only when execution fails.

    Parameters
    ----------
    query
        A raw SQL string, or a sqlglot expression compiled to this
        backend's dialect before execution.
    kwargs
        Extra arguments forwarded to `cursor.execute`.
    """
    # Strings have no `.sql` method, so they fall through unchanged.
    with contextlib.suppress(AttributeError):
        query = query.sql(dialect=self.dialect)

    cursor = self.con.cursor()
    try:
        cursor.execute(query, **kwargs)
    except Exception:
        # Don't leak the cursor when execution fails.
        cursor.close()
        raise
    return cursor
745
744
746
745
@util .experimental
@@ -757,33 +756,22 @@ def to_pyarrow_batches(
757
756
import pandas as pd
758
757
import pyarrow as pa
759
758
760
def _batches(self: Self, *, schema: pa.Schema, query: str):
    """Stream `query` results as pyarrow record batches.

    Batch size comes from `chunk_size` in the enclosing scope.
    """
    columns = schema.names
    # Server-side cursors need unique names; `withhold=True` keeps the
    # cursor readable after the opening transaction commits.
    cursor_name = util.gen_name("postgres_cursor")
    with self.begin(name=cursor_name, withhold=True) as cursor:
        cursor.execute(query)
        batch = cursor.fetchmany(chunk_size)
        while batch:
            yield pa.RecordBatch.from_pandas(
                pd.DataFrame(batch, columns=columns), schema=schema
            )
            batch = cursor.fetchmany(chunk_size)
782
770
783
771
self ._run_pre_execute_hooks (expr )
784
772
785
773
schema = expr .as_table ().schema ().to_pyarrow ()
786
774
query = self .compile (expr , limit = limit , params = params )
787
775
return pa .RecordBatchReader .from_batches (
788
- schema , _batches (schema = schema , con = self . con , query = query )
776
+ schema , _batches (self , schema = schema , query = query )
789
777
)
0 commit comments