diff --git a/src/include/postgres_connection.hpp b/src/include/postgres_connection.hpp index 3352a847..8cb89ce5 100644 --- a/src/include/postgres_connection.hpp +++ b/src/include/postgres_connection.hpp @@ -55,6 +55,8 @@ class PostgresConnection { void Execute(const string &query); unique_ptr Query(const string &query); + PostgresVersion GetPostgresVersion(); + vector GetIndexInfo(const string &table_name); void BeginCopyTo(ClientContext &context, PostgresCopyState &state, PostgresCopyFormat format, diff --git a/src/include/postgres_scanner.hpp b/src/include/postgres_scanner.hpp index 80ed9f29..ac2e9a32 100644 --- a/src/include/postgres_scanner.hpp +++ b/src/include/postgres_scanner.hpp @@ -60,7 +60,7 @@ class PostgresScanFunction : public TableFunction { public: PostgresScanFunction(); - static void PrepareBind(ClientContext &context, PostgresBindData &bind); + static void PrepareBind(PostgresVersion version, ClientContext &context, PostgresBindData &bind); }; class PostgresScanFunctionFilterPushdown : public TableFunction { diff --git a/src/include/postgres_utils.hpp b/src/include/postgres_utils.hpp index 9b3a4658..de815841 100644 --- a/src/include/postgres_utils.hpp +++ b/src/include/postgres_utils.hpp @@ -10,6 +10,7 @@ #include "duckdb.hpp" #include +#include "postgres_version.hpp" namespace duckdb { class PostgresSchemaEntry; @@ -44,6 +45,8 @@ class PostgresUtils { static bool SupportedPostgresOid(const LogicalType &input); static LogicalType RemoveAlias(const LogicalType &type); static PostgresType CreateEmptyPostgresType(const LogicalType &type); + + static PostgresVersion ExtractPostgresVersion(const string &version); }; } // namespace duckdb diff --git a/src/include/postgres_version.hpp b/src/include/postgres_version.hpp new file mode 100644 index 00000000..f57c95a0 --- /dev/null +++ b/src/include/postgres_version.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// postgres_version.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +struct PostgresVersion { + PostgresVersion() { + } + PostgresVersion(idx_t major_v, idx_t minor_v, idx_t patch_v = 0) : major_v(major_v), minor_v(minor_v), patch_v(patch_v) { + } + + idx_t major_v = 0; + idx_t minor_v = 0; + idx_t patch_v = 0; + + inline bool operator<(const PostgresVersion &rhs) const { + if (major_v < rhs.major_v) { + return true; + } + if (major_v > rhs.major_v) { + return false; + } + if (minor_v < rhs.minor_v) { + return true; + } + if (minor_v > rhs.minor_v) { + return false; + } + return patch_v < rhs.patch_v; + }; + inline bool operator<=(const PostgresVersion &rhs) const { + return !(rhs < *this); + }; + inline bool operator>(const PostgresVersion &rhs) const { + return rhs < *this; + }; + inline bool operator>=(const PostgresVersion &rhs) const { + return !(*this < rhs); + }; +}; + +} // namespace duckdb diff --git a/src/include/storage/postgres_catalog.hpp b/src/include/storage/postgres_catalog.hpp index bfda382c..a7975b71 100644 --- a/src/include/storage/postgres_catalog.hpp +++ b/src/include/storage/postgres_catalog.hpp @@ -19,7 +19,8 @@ class PostgresSchemaEntry; class PostgresCatalog : public Catalog { public: - explicit PostgresCatalog(AttachedDatabase &db_p, const string &path, AccessMode access_mode); + explicit PostgresCatalog(PostgresVersion version, AttachedDatabase &db_p, const string &path, + AccessMode access_mode); ~PostgresCatalog(); string path; @@ -52,6 +53,10 @@ class PostgresCatalog : public Catalog { DatabaseSize GetDatabaseSize(ClientContext &context) override; + PostgresVersion GetPostgresVersion() const { + return version; + } + //! Label all postgres scans in the sub-tree as requiring materialization //! This is used for e.g. insert queries that have both (1) a scan from a postgres table, and (2) a sink into one static void MaterializePostgresScans(PhysicalOperator &op); @@ -70,6 +75,7 @@ class PostgresCatalog : public Catalog { void DropSchema(ClientContext &context, DropInfo &info) override; private: + PostgresVersion version; PostgresSchemaSet schemas; PostgresConnectionPool connection_pool; }; diff --git a/src/postgres_connection.cpp b/src/postgres_connection.cpp index 4e8963c4..3d82b711 100644 --- a/src/postgres_connection.cpp +++ b/src/postgres_connection.cpp @@ -66,6 +66,12 @@ void PostgresConnection::Execute(const string &query) { Query(query); } +PostgresVersion PostgresConnection::GetPostgresVersion() { + auto result = Query("SHOW server_version;"); + auto version = PostgresUtils::ExtractPostgresVersion(result->GetString(0, 0)); + return version; +} + bool PostgresConnection::IsOpen() { return connection.get(); } diff --git a/src/postgres_scanner.cpp b/src/postgres_scanner.cpp index 8ba8a499..6c27d027 100644 --- a/src/postgres_scanner.cpp +++ b/src/postgres_scanner.cpp @@ -45,16 +45,25 @@ struct PostgresGlobalState : public GlobalTableFunctionState { } }; -void PostgresScanFunction::PrepareBind(ClientContext &context, PostgresBindData &bind_data) { +void PostgresScanFunction::PrepareBind(PostgresVersion version, ClientContext &context, PostgresBindData &bind_data) { // we create a transaction here, and get the snapshot id so the parallel // reader threads can use the same snapshot auto &con = bind_data.connection; - auto result = con.Query("SELECT pg_is_in_recovery(), pg_export_snapshot()"); - bind_data.in_recovery = result->GetBool(0, 0); - bind_data.snapshot = ""; - - if (!bind_data.in_recovery) { - bind_data.snapshot = result->GetString(0, 1); + // pg_stat_wal_receiver was introduced in PostgreSQL 9.6 + if (version >= PostgresVersion(9, 6, 0)) { + auto result = con.Query("SELECT pg_is_in_recovery(), pg_export_snapshot(), (select count(*) from pg_stat_wal_receiver)"); + bind_data.in_recovery = result->GetBool(0, 0) || result->GetInt64(0, 2) > 0; + bind_data.snapshot = ""; + if (!bind_data.in_recovery) { + bind_data.snapshot = result->GetString(0, 1); + } + } else { + auto result = con.Query("SELECT pg_is_in_recovery(), pg_export_snapshot()"); + bind_data.in_recovery = result->GetBool(0, 0); + bind_data.snapshot = ""; + if (!bind_data.in_recovery) { + bind_data.snapshot = result->GetString(0, 1); + } } Value pages_per_task; @@ -85,7 +94,8 @@ static unique_ptr PostgresBind(ClientContext &context, TableFuncti bind_data->connection = PostgresConnection::Open(bind_data->dsn); bind_data->connection.Execute("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ READ ONLY"); - PostgresScanFunction::PrepareBind(context, *bind_data); + auto version = bind_data->connection.GetPostgresVersion(); + PostgresScanFunction::PrepareBind(version, context, *bind_data); // query the table schema so we can interpret the bits in the pages auto info = PostgresTableSet::GetTableInfo(bind_data->connection, bind_data->schema_name, bind_data->table_name); diff --git a/src/postgres_storage.cpp b/src/postgres_storage.cpp index 7297d9d2..89c688f7 100644 --- a/src/postgres_storage.cpp +++ b/src/postgres_storage.cpp @@ -9,7 +9,9 @@ namespace duckdb { static unique_ptr PostgresAttach(StorageExtensionInfo *storage_info, AttachedDatabase &db, const string &name, AttachInfo &info, AccessMode access_mode) { - return make_uniq(db, info.path, access_mode); + auto connection = PostgresConnection::Open(info.path); + auto version = connection.GetPostgresVersion(); + return make_uniq(version, db, info.path, access_mode); } static unique_ptr PostgresCreateTransactionManager(StorageExtensionInfo *storage_info, diff --git a/src/postgres_utils.cpp b/src/postgres_utils.cpp index a8187210..67b55498 100644 --- a/src/postgres_utils.cpp +++ b/src/postgres_utils.cpp @@ -325,4 +325,44 @@ uint32_t PostgresUtils::ToPostgresOid(const LogicalType &input) { } } +PostgresVersion PostgresUtils::ExtractPostgresVersion(const string &version_str) { + PostgresVersion result; + idx_t pos = 0; + // scan for the first digit + while(pos < version_str.size() && !StringUtil::CharacterIsDigit(version_str[pos])) { + pos++; + } + for(idx_t version_idx = 0; version_idx < 3; version_idx++) { + idx_t digit_start = pos; + while(pos < version_str.size() && StringUtil::CharacterIsDigit(version_str[pos])) { + pos++; + } + if (digit_start == pos) { + // no digits + break; + } + // our version is at [digit_start..pos) + auto digit_str = version_str.substr(digit_start, pos - digit_start); + auto digit = std::strtoll(digit_str.c_str(), 0, 10); + switch(version_idx) { + case 0: + result.major_v = digit; + break; + case 1: + result.minor_v = digit; + break; + default: + result.patch_v = digit; + break; + } + + // check if the next character is a dot, if not we stop + if (pos >= version_str.size() || version_str[pos] != '.') { + break; + } + pos++; + } + return result; +} + } diff --git a/src/storage/postgres_catalog.cpp b/src/storage/postgres_catalog.cpp index 96043e30..ef425e06 100644 --- a/src/storage/postgres_catalog.cpp +++ b/src/storage/postgres_catalog.cpp @@ -9,8 +9,8 @@ namespace duckdb { -PostgresCatalog::PostgresCatalog(AttachedDatabase &db_p, const string &path, AccessMode access_mode) - : Catalog(db_p), path(path), access_mode(access_mode), schemas(*this) { +PostgresCatalog::PostgresCatalog(PostgresVersion version, AttachedDatabase &db_p, const string &path, AccessMode access_mode) + : Catalog(db_p), path(path), access_mode(access_mode), version(version), schemas(*this) { Value connection_limit; auto &db_instance = db_p.GetDatabase(); if (db_instance.TryGetCurrentSetting("pg_connection_limit", connection_limit)) { diff --git a/src/storage/postgres_table_entry.cpp b/src/storage/postgres_table_entry.cpp index b760984e..010bbca4 100644 --- a/src/storage/postgres_table_entry.cpp +++ b/src/storage/postgres_table_entry.cpp @@ -35,6 +35,7 @@ void PostgresTableEntry::BindUpdateConstraints(LogicalGet &, LogicalProjection & } TableFunction PostgresTableEntry::GetScanFunction(ClientContext &context, unique_ptr &bind_data) { + auto &pg_catalog = catalog.Cast(); auto &transaction = Transaction::Get(context, catalog).Cast(); auto &conn = transaction.GetConnection(); @@ -46,7 +47,7 @@ TableFunction PostgresTableEntry::GetScanFunction(ClientContext &context, unique result->transaction = &transaction; result->connection = PostgresConnection(conn.GetConnection()); - PostgresScanFunction::PrepareBind(context, *result); + PostgresScanFunction::PrepareBind(pg_catalog.GetPostgresVersion(), context, *result); for(auto &col : columns.Logical()) { result->types.push_back(col.GetType()); } @@ -57,7 +58,7 @@ TableFunction PostgresTableEntry::GetScanFunction(ClientContext &context, unique // check how many threads we can actually use if (result->max_threads > 1) { - auto &connection_pool = catalog.Cast().GetConnectionPool(); + auto &connection_pool = pg_catalog.GetConnectionPool(); result->connection_reservation = connection_pool.AllocateConnections(result->max_threads); result->max_threads = result->connection_reservation.GetConnectionCount(); } diff --git a/test/sql/storage/attach_non_existent.test b/test/sql/storage/attach_non_existent.test new file mode 100644 index 00000000..763a760c --- /dev/null +++ b/test/sql/storage/attach_non_existent.test @@ -0,0 +1,10 @@ +# name: test/sql/storage/attach_non_existent.test +# description: Test attaching to a database that does not exist +# group: [storage] + +require postgres_scanner + +statement error +ATTACH 'dbname=dbdoesnotexistx' AS s1 (TYPE POSTGRES) +---- +does not exist