diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9dd797002..8a826cf91 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -195,6 +195,28 @@ jobs: - name: check commits run: scripts/check-each-commit.sh upstream/${{ github.base_ref }} + ####################### + # sql model generation + ####################### + sqlc-check: + name: Sqlc check + runs-on: ubuntu-latest + steps: + - name: git checkout + uses: actions/checkout@v3 + + - name: setup go ${{ env.GO_VERSION }} + uses: ./.github/actions/setup-go + with: + go-version: '${{ env.GO_VERSION }}' + + - name: docker image cache + uses: jpribyl/action-docker-layer-caching@v0.1.1 + continue-on-error: true + + - name: Generate sql models + run: make sqlc-check + ######################## # lint code ######################## diff --git a/Makefile b/Makefile index e248ac5db..648adf0bd 100644 --- a/Makefile +++ b/Makefile @@ -300,6 +300,14 @@ clean: clean-itest $(RM) ./litd-debug $(RM) coverage.txt +sqlc: + @$(call print, "Generating sql models and queries in Go") + ./scripts/gen_sqlc_docker.sh + +sqlc-check: sqlc + @$(call print, "Verifying sql code generation.") + if test -n "$$(git status --porcelain '*.go')"; then echo "SQL models not properly generated!"; git status --porcelain '*.go'; exit 1; fi + # Prevent make from interpreting any of the defined goals as folders or files to # include in the build process. .PHONY: default all yarn-install build install go-build go-build-noui \ diff --git a/db/interfaces.go b/db/interfaces.go new file mode 100644 index 000000000..3a0378f55 --- /dev/null +++ b/db/interfaces.go @@ -0,0 +1,326 @@ +package db + +import ( + "context" + "database/sql" + "math" + prand "math/rand" + "time" + + "github.com/lightninglabs/lightning-terminal/db/sqlc" +) + +var ( + // DefaultStoreTimeout is the default timeout used for any interaction + // with the storage/database. + DefaultStoreTimeout = time.Second * 10 +) + +const ( + // DefaultNumTxRetries is the default number of times we'll retry a + // transaction if it fails with an error that permits transaction + // repetition. + DefaultNumTxRetries = 10 + + // DefaultInitialRetryDelay is the default initial delay between + // retries. This will be used to generate a random delay between -50% + // and +50% of this value, so 20 to 60 milliseconds. The retry will be + // doubled after each attempt until we reach DefaultMaxRetryDelay. We + // start with a random value to avoid multiple goroutines that are + // created at the same time to effectively retry at the same time. + DefaultInitialRetryDelay = time.Millisecond * 40 + + // DefaultMaxRetryDelay is the default maximum delay between retries. + DefaultMaxRetryDelay = time.Second * 3 +) + +// TxOptions represents a set of options one can use to control what type of +// database transaction is created. Transaction can wither be read or write. +type TxOptions interface { + // ReadOnly returns true if the transaction should be read only. + ReadOnly() bool +} + +// BatchedTx is a generic interface that represents the ability to execute +// several operations to a given storage interface in a single atomic +// transaction. Typically, Q here will be some subset of the main sqlc.Querier +// interface allowing it to only depend on the routines it needs to implement +// any additional business logic. +type BatchedTx[Q any] interface { + // ExecTx will execute the passed txBody, operating upon generic + // parameter Q (usually a storage interface) in a single transaction. + // The set of TxOptions are passed in in order to allow the caller to + // specify if a transaction should be read-only and optionally what + // type of concurrency control should be used. + ExecTx(ctx context.Context, txOptions TxOptions, + txBody func(Q) error) error + + // Backend returns the type of the database backend used. + Backend() sqlc.BackendType +} + +// Tx represents a database transaction that can be committed or rolled back. +type Tx interface { + // Commit commits the database transaction, an error should be returned + // if the commit isn't possible. + Commit() error + + // Rollback rolls back an incomplete database transaction. + // Transactions that were able to be committed can still call this as a + // noop. + Rollback() error +} + +// QueryCreator is a generic function that's used to create a Querier, which is +// a type of interface that implements storage related methods from a database +// transaction. This will be used to instantiate an object callers can use to +// apply multiple modifications to an object interface in a single atomic +// transaction. +type QueryCreator[Q any] func(*sql.Tx) Q + +// BatchedQuerier is a generic interface that allows callers to create a new +// database transaction based on an abstract type that implements the TxOptions +// interface. +type BatchedQuerier interface { + // Querier is the underlying query source, this is in place so we can + // pass a BatchedQuerier implementation directly into objects that + // create a batched version of the normal methods they need. + sqlc.Querier + + // BeginTx creates a new database transaction given the set of + // transaction options. + BeginTx(ctx context.Context, options TxOptions) (*sql.Tx, error) + + // Backend returns the type of the database backend used. + Backend() sqlc.BackendType +} + +// txExecutorOptions is a struct that holds the options for the transaction +// executor. This can be used to do things like retry a transaction due to an +// error a certain amount of times. +type txExecutorOptions struct { + numRetries int + initialRetryDelay time.Duration + maxRetryDelay time.Duration +} + +// defaultTxExecutorOptions returns the default options for the transaction +// executor. +func defaultTxExecutorOptions() *txExecutorOptions { + return &txExecutorOptions{ + numRetries: DefaultNumTxRetries, + initialRetryDelay: DefaultInitialRetryDelay, + maxRetryDelay: DefaultMaxRetryDelay, + } +} + +// randRetryDelay returns a random retry delay between -50% and +50% +// of the configured delay that is doubled for each attempt and capped at a max +// value. +func (t *txExecutorOptions) randRetryDelay(attempt int) time.Duration { + halfDelay := t.initialRetryDelay / 2 + randDelay := prand.Int63n(int64(t.initialRetryDelay)) //nolint:gosec + + // 50% plus 0%-100% gives us the range of 50%-150%. + initialDelay := halfDelay + time.Duration(randDelay) + + // If this is the first attempt, we just return the initial delay. + if attempt == 0 { + return initialDelay + } + + // For each subsequent delay, we double the initial delay. This still + // gives us a somewhat random delay, but it still increases with each + // attempt. If we double something n times, that's the same as + // multiplying the value with 2^n. We limit the power to 32 to avoid + // overflows. + factor := time.Duration(math.Pow(2, math.Min(float64(attempt), 32))) + actualDelay := initialDelay * factor + + // Cap the delay at the maximum configured value. + if actualDelay > t.maxRetryDelay { + return t.maxRetryDelay + } + + return actualDelay +} + +// TxExecutorOption is a functional option that allows us to pass in optional +// argument when creating the executor. +type TxExecutorOption func(*txExecutorOptions) + +// WithTxRetries is a functional option that allows us to specify the number of +// times a transaction should be retried if it fails with a repeatable error. +func WithTxRetries(numRetries int) TxExecutorOption { + return func(o *txExecutorOptions) { + o.numRetries = numRetries + } +} + +// WithTxRetryDelay is a functional option that allows us to specify the delay +// to wait before a transaction is retried. +func WithTxRetryDelay(delay time.Duration) TxExecutorOption { + return func(o *txExecutorOptions) { + o.initialRetryDelay = delay + } +} + +// TransactionExecutor is a generic struct that abstracts away from the type of +// query a type needs to run under a database transaction, and also the set of +// options for that transaction. The QueryCreator is used to create a query +// given a database transaction created by the BatchedQuerier. +type TransactionExecutor[Query any] struct { + BatchedQuerier + + createQuery QueryCreator[Query] + + opts *txExecutorOptions +} + +// NewTransactionExecutor creates a new instance of a TransactionExecutor given +// a Querier query object and a concrete type for the type of transactions the +// Querier understands. +func NewTransactionExecutor[Querier any](db BatchedQuerier, + createQuery QueryCreator[Querier], + opts ...TxExecutorOption) *TransactionExecutor[Querier] { + + txOpts := defaultTxExecutorOptions() + for _, optFunc := range opts { + optFunc(txOpts) + } + + return &TransactionExecutor[Querier]{ + BatchedQuerier: db, + createQuery: createQuery, + opts: txOpts, + } +} + +// ExecTx is a wrapper for txBody to abstract the creation and commit of a db +// transaction. The db transaction is embedded in a `*Queries` that txBody +// needs to use when executing each one of the queries that need to be applied +// atomically. This can be used by other storage interfaces to parameterize the +// type of query and options run, in order to have access to batched operations +// related to a storage object. +func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context, + txOptions TxOptions, txBody func(Q) error) error { + + waitBeforeRetry := func(attemptNumber int) { + retryDelay := t.opts.randRetryDelay(attemptNumber) + + log.Tracef("Retrying transaction due to tx serialization or "+ + "deadlock error, attempt_number=%v, delay=%v", + attemptNumber, retryDelay) + + // Before we try again, we'll wait with a random backoff based + // on the retry delay. + time.Sleep(retryDelay) + } + + for i := 0; i < t.opts.numRetries; i++ { + // Create the db transaction. + tx, err := t.BatchedQuerier.BeginTx(ctx, txOptions) + if err != nil { + dbErr := MapSQLError(err) + if IsSerializationOrDeadlockError(dbErr) { + // Nothing to roll back here, since we didn't + // even get a transaction yet. + waitBeforeRetry(i) + continue + } + + return dbErr + } + + // Rollback is safe to call even if the tx is already closed, + // so if the tx commits successfully, this is a no-op. + defer func() { + _ = tx.Rollback() + }() + + if err := txBody(t.createQuery(tx)); err != nil { + dbErr := MapSQLError(err) + if IsSerializationOrDeadlockError(dbErr) { + // Roll back the transaction, then pop back up + // to try once again. + _ = tx.Rollback() + + waitBeforeRetry(i) + continue + } + + return dbErr + } + + // Commit transaction. + if err = tx.Commit(); err != nil { + dbErr := MapSQLError(err) + if IsSerializationOrDeadlockError(dbErr) { + // Roll back the transaction, then pop back up + // to try once again. + _ = tx.Rollback() + + waitBeforeRetry(i) + continue + } + + return dbErr + } + + return nil + } + + // If we get to this point, then we weren't able to successfully commit + // a tx given the max number of retries. + return ErrRetriesExceeded +} + +// Backend returns the type of the database backend used. +func (t *TransactionExecutor[Q]) Backend() sqlc.BackendType { + return t.BatchedQuerier.Backend() +} + +// BaseDB is the base database struct that each implementation can embed to +// gain some common functionality. +type BaseDB struct { + *sql.DB + + *sqlc.Queries +} + +// BeginTx wraps the normal sql specific BeginTx method with the TxOptions +// interface. This interface is then mapped to the concrete sql tx options +// struct. +func (s *BaseDB) BeginTx(ctx context.Context, opts TxOptions) (*sql.Tx, error) { + sqlOptions := sql.TxOptions{ + ReadOnly: opts.ReadOnly(), + Isolation: sql.LevelSerializable, + } + return s.DB.BeginTx(ctx, &sqlOptions) +} + +// Backend returns the type of the database backend used. +func (s *BaseDB) Backend() sqlc.BackendType { + return s.Queries.Backend() +} + +// QueriesTxOptions defines the set of db txn options the SQLQueries +// understands. +type QueriesTxOptions struct { + // readOnly governs if a read only transaction is needed or not. + readOnly bool +} + +// ReadOnly returns true if the transaction should be read only. +// +// NOTE: This implements the TxOptions. +func (a *QueriesTxOptions) ReadOnly() bool { + return a.readOnly +} + +// NewQueryReadTx creates a new read transaction option set. +func NewQueryReadTx() QueriesTxOptions { + return QueriesTxOptions{ + readOnly: true, + } +} diff --git a/db/log.go b/db/log.go new file mode 100644 index 000000000..19c12cf48 --- /dev/null +++ b/db/log.go @@ -0,0 +1,25 @@ +package db + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +const Subsystem = "SQLD" + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/db/migrations.go b/db/migrations.go new file mode 100644 index 000000000..e10426fbd --- /dev/null +++ b/db/migrations.go @@ -0,0 +1,281 @@ +package db + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/fs" + "net/http" + "strings" + + "github.com/btcsuite/btclog" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + "github.com/golang-migrate/migrate/v4/source/httpfs" + "github.com/lightninglabs/taproot-assets/fn" +) + +const ( + // LatestMigrationVersion is the latest migration version of the + // database. This is used to implement downgrade protection for the + // daemon. + // + // NOTE: This MUST be updated when a new migration is added. + LatestMigrationVersion = 1 +) + +// MigrationTarget is a functional option that can be passed to applyMigrations +// to specify a target version to migrate to. `currentDbVersion` is the current +// (migration) version of the database, or None if unknown. +// `maxMigrationVersion` is the maximum migration version known to the driver, +// or None if unknown. +type MigrationTarget func(mig *migrate.Migrate, + currentDbVersion int, maxMigrationVersion uint) error + +var ( + // TargetLatest is a MigrationTarget that migrates to the latest + // version available. + TargetLatest = func(mig *migrate.Migrate, _ int, _ uint) error { + return mig.Up() + } + + // TargetVersion is a MigrationTarget that migrates to the given + // version. + TargetVersion = func(version uint) MigrationTarget { + return func(mig *migrate.Migrate, _ int, _ uint) error { + return mig.Migrate(version) + } + } +) + +var ( + // ErrMigrationDowngrade is returned when a database downgrade is + // detected. + ErrMigrationDowngrade = errors.New("database downgrade detected") +) + +// migrationOption is a functional option that can be passed to migrate related +// methods to modify their behavior. +type migrateOptions struct { + latestVersion fn.Option[uint] +} + +// defaultMigrateOptions returns a new migrateOptions instance with default +// settings. +func defaultMigrateOptions() *migrateOptions { + return &migrateOptions{} +} + +// MigrateOpt is a functional option that can be passed to migrate related +// methods to modify behavior. +type MigrateOpt func(*migrateOptions) + +// WithLatestVersion allows callers to override the default latest version +// setting. +func WithLatestVersion(version uint) MigrateOpt { + return func(o *migrateOptions) { + o.latestVersion = fn.Some(version) + } +} + +// migrationLogger is a logger that wraps the passed btclog.Logger so it can be +// used to log migrations. +type migrationLogger struct { + log btclog.Logger +} + +// Printf is like fmt.Printf. We map this to the target logger based on the +// current log level. +func (m *migrationLogger) Printf(format string, v ...interface{}) { + // Trim trailing newlines from the format. + format = strings.TrimRight(format, "\n") + + switch m.log.Level() { + case btclog.LevelTrace: + m.log.Tracef(format, v...) + case btclog.LevelDebug: + m.log.Debugf(format, v...) + case btclog.LevelInfo: + m.log.Infof(format, v...) + case btclog.LevelWarn: + m.log.Warnf(format, v...) + case btclog.LevelError: + m.log.Errorf(format, v...) + case btclog.LevelCritical: + m.log.Criticalf(format, v...) + case btclog.LevelOff: + } +} + +// Verbose should return true when verbose logging output is wanted +func (m *migrationLogger) Verbose() bool { + return m.log.Level() <= btclog.LevelDebug +} + +// applyMigrations executes database migration files found in the given file +// system under the given path, using the passed database driver and database +// name, up to or down to the given target version. +func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string, + targetVersion MigrationTarget, opts *migrateOptions) error { + + // With the migrate instance open, we'll create a new migration source + // using the embedded file system stored in sqlSchemas. The library + // we're using can't handle a raw file system interface, so we wrap it + // in this intermediate layer. + migrateFileServer, err := httpfs.New(http.FS(fs), path) + if err != nil { + return err + } + + // Finally, we'll run the migration with our driver above based on the + // open DB, and also the migration source stored in the file system + // above. + sqlMigrate, err := migrate.NewWithInstance( + "migrations", migrateFileServer, dbName, driver, + ) + if err != nil { + return err + } + + migrationVersion, _, _ := sqlMigrate.Version() + + // As the down migrations may end up *dropping* data, we want to + // prevent that without explicit accounting. + latestVersion := opts.latestVersion.UnwrapOr(LatestMigrationVersion) + if migrationVersion > latestVersion { + return fmt.Errorf("%w: database version is newer than the "+ + "latest migration version, preventing downgrade: "+ + "db_version=%v, latest_migration_version=%v", + ErrMigrationDowngrade, migrationVersion, latestVersion) + } + + // Report the current version of the database before the migration. + currentDbVersion, _, err := driver.Version() + if err != nil { + return fmt.Errorf("unable to get current db version: %w", err) + } + log.Infof("Attempting to apply migration(s) "+ + "(current_db_version=%v, latest_migration_version=%v)", + currentDbVersion, latestVersion) + + // Apply our local logger to the migration instance. + sqlMigrate.Log = &migrationLogger{log} + + // Execute the migration based on the target given. + err = targetVersion(sqlMigrate, currentDbVersion, latestVersion) + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + return err + } + + // Report the current version of the database after the migration. + currentDbVersion, _, err = driver.Version() + if err != nil { + return fmt.Errorf("unable to get current db version: %w", err) + } + log.Infof("Database version after migration: %v", currentDbVersion) + + return nil +} + +// replacerFS is an implementation of a fs.FS virtual file system that wraps an +// existing file system but does a search-and-replace operation on each file +// when it is opened. +type replacerFS struct { + parentFS fs.FS + replaces map[string]string +} + +// A compile-time assertion to make sure replacerFS implements the fs.FS +// interface. +var _ fs.FS = (*replacerFS)(nil) + +// newReplacerFS creates a new replacer file system, wrapping the given parent +// virtual file system. Each file within the file system is undergoing a +// search-and-replace operation when it is opened, using the given map where the +// key denotes the search term and the value the term to replace each occurrence +// with. +func newReplacerFS(parent fs.FS, replaces map[string]string) *replacerFS { + return &replacerFS{ + parentFS: parent, + replaces: replaces, + } +} + +// Open opens a file in the virtual file system. +// +// NOTE: This is part of the fs.FS interface. +func (t *replacerFS) Open(name string) (fs.File, error) { + f, err := t.parentFS.Open(name) + if err != nil { + return nil, err + } + + stat, err := f.Stat() + if err != nil { + return nil, err + } + + if stat.IsDir() { + return f, err + } + + return newReplacerFile(f, t.replaces) +} + +type replacerFile struct { + parentFile fs.File + buf bytes.Buffer +} + +// A compile-time assertion to make sure replacerFile implements the fs.File +// interface. +var _ fs.File = (*replacerFile)(nil) + +func newReplacerFile(parent fs.File, replaces map[string]string) (*replacerFile, + error) { + + content, err := io.ReadAll(parent) + if err != nil { + return nil, err + } + + contentStr := string(content) + for from, to := range replaces { + contentStr = strings.ReplaceAll(contentStr, from, to) + } + + var buf bytes.Buffer + _, err = buf.WriteString(contentStr) + if err != nil { + return nil, err + } + + return &replacerFile{ + parentFile: parent, + buf: buf, + }, nil +} + +// Stat returns statistics/info about the file. +// +// NOTE: This is part of the fs.File interface. +func (t *replacerFile) Stat() (fs.FileInfo, error) { + return t.parentFile.Stat() +} + +// Read reads as many bytes as possible from the file into the given slice. +// +// NOTE: This is part of the fs.File interface. +func (t *replacerFile) Read(bytes []byte) (int, error) { + return t.buf.Read(bytes) +} + +// Close closes the underlying file. +// +// NOTE: This is part of the fs.File interface. +func (t *replacerFile) Close() error { + // We already fully read and then closed the file when creating this + // instance, so there's nothing to do for us here. + return nil +} diff --git a/db/postgres.go b/db/postgres.go new file mode 100644 index 000000000..16e41dc09 --- /dev/null +++ b/db/postgres.go @@ -0,0 +1,207 @@ +package db + +import ( + "database/sql" + "fmt" + "testing" + "time" + + postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres" + _ "github.com/golang-migrate/migrate/v4/source/file" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/stretchr/testify/require" +) + +const ( + dsnTemplate = "postgres://%v:%v@%v:%d/%v?sslmode=%v" + + // defaultMaxIdleConns is the number of permitted idle connections. + defaultMaxIdleConns = 6 + + // defaultConnMaxIdleTime is the amount of time a connection can be + // idle before it is closed. + defaultConnMaxIdleTime = 5 * time.Minute +) + +var ( + // DefaultPostgresFixtureLifetime is the default maximum time a Postgres + // test fixture is being kept alive. After that time the docker + // container will be terminated forcefully, even if the tests aren't + // fully executed yet. So this time needs to be chosen correctly to be + // longer than the longest expected individual test run time. + DefaultPostgresFixtureLifetime = 60 * time.Minute + + // postgresSchemaReplacements is a map of schema strings that need to be + // replaced for postgres. This is needed because we write the schemas + // to work with sqlite primarily, and postgres has some differences. + postgresSchemaReplacements = map[string]string{ + "BLOB": "BYTEA", + "INTEGER PRIMARY KEY": "BIGSERIAL PRIMARY KEY", + "TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE", + "UNHEX": "DECODE", + } +) + +// PostgresConfig holds the postgres database configuration. +// +// nolint:lll +type PostgresConfig struct { + SkipMigrations bool `long:"skipmigrations" description:"Skip applying migrations on startup."` + Host string `long:"host" description:"Database server hostname."` + Port int `long:"port" description:"Database server port."` + User string `long:"user" description:"Database user."` + Password string `long:"password" description:"Database user's password."` + DBName string `long:"dbname" description:"Database name to use."` + MaxOpenConnections int `long:"maxconnections" description:"Max open connections to keep alive to the database server."` + MaxIdleConnections int `long:"maxidleconnections" description:"Max number of idle connections to keep in the connection pool."` + ConnMaxLifetime time.Duration `long:"connmaxlifetime" description:"Max amount of time a connection can be reused for before it is closed. Valid time units are {s, m, h}."` + ConnMaxIdleTime time.Duration `long:"connmaxidletime" description:"Max amount of time a connection can be idle for before it is closed. Valid time units are {s, m, h}."` + RequireSSL bool `long:"requiressl" description:"Whether to require using SSL (mode: require) when connecting to the server."` +} + +// DSN returns the dns to connect to the database. +func (s *PostgresConfig) DSN(hidePassword bool) string { + var sslMode = "disable" + if s.RequireSSL { + sslMode = "require" + } + + password := s.Password + if hidePassword { + // Placeholder used for logging the DSN safely. + password = "****" + } + + return fmt.Sprintf(dsnTemplate, s.User, password, s.Host, s.Port, + s.DBName, sslMode) +} + +// PostgresStore is a database store implementation that uses a Postgres +// backend. +type PostgresStore struct { + cfg *PostgresConfig + + *BaseDB +} + +// NewPostgresStore creates a new store that is backed by a Postgres database +// backend. +func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { + log.Infof("Using SQL database '%s'", cfg.DSN(true)) + + rawDb, err := sql.Open("pgx", cfg.DSN(false)) + if err != nil { + return nil, err + } + + maxConns := defaultMaxConns + if cfg.MaxOpenConnections > 0 { + maxConns = cfg.MaxOpenConnections + } + + maxIdleConns := defaultMaxIdleConns + if cfg.MaxIdleConnections > 0 { + maxIdleConns = cfg.MaxIdleConnections + } + + connMaxLifetime := defaultConnMaxLifetime + if cfg.ConnMaxLifetime > 0 { + connMaxLifetime = cfg.ConnMaxLifetime + } + + connMaxIdleTime := defaultConnMaxIdleTime + if cfg.ConnMaxIdleTime > 0 { + connMaxIdleTime = cfg.ConnMaxIdleTime + } + + rawDb.SetMaxOpenConns(maxConns) + rawDb.SetMaxIdleConns(maxIdleConns) + rawDb.SetConnMaxLifetime(connMaxLifetime) + rawDb.SetConnMaxIdleTime(connMaxIdleTime) + + queries := sqlc.NewPostgres(rawDb) + s := &PostgresStore{ + cfg: cfg, + BaseDB: &BaseDB{ + DB: rawDb, + Queries: queries, + }, + } + + // Now that the database is open, populate the database with our set of + // schemas based on our embedded in-memory file system. + if !cfg.SkipMigrations { + if err := s.ExecuteMigrations(TargetLatest); err != nil { + return nil, fmt.Errorf("error executing migrations: "+ + "%w", err) + } + } + + return s, nil +} + +// ExecuteMigrations runs migrations for the Postgres database, depending on the +// target given, either all migrations or up to a given version. +func (s *PostgresStore) ExecuteMigrations(target MigrationTarget, + optFuncs ...MigrateOpt) error { + + opts := defaultMigrateOptions() + for _, optFunc := range optFuncs { + optFunc(opts) + } + + driver, err := postgres_migrate.WithInstance( + s.DB, &postgres_migrate.Config{}, + ) + if err != nil { + return fmt.Errorf("error creating postgres migration: %w", err) + } + + postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements) + return applyMigrations( + postgresFS, driver, "sqlc/migrations", s.cfg.DBName, target, + opts, + ) +} + +// NewTestPostgresDB is a helper function that creates a Postgres database for +// testing. +func NewTestPostgresDB(t *testing.T) *PostgresStore { + t.Helper() + + t.Logf("Creating new Postgres DB for testing") + + sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true) + store, err := NewPostgresStore(sqlFixture.GetConfig()) + require.NoError(t, err) + + t.Cleanup(func() { + sqlFixture.TearDown(t) + }) + + return store +} + +// NewTestPostgresDBWithVersion is a helper function that creates a Postgres +// database for testing and migrates it to the given version. +func NewTestPostgresDBWithVersion(t *testing.T, version uint) *PostgresStore { + t.Helper() + + t.Logf("Creating new Postgres DB for testing, migrating to version %d", + version) + + sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true) + storeCfg := sqlFixture.GetConfig() + storeCfg.SkipMigrations = true + store, err := NewPostgresStore(storeCfg) + require.NoError(t, err) + + err = store.ExecuteMigrations(TargetVersion(version)) + require.NoError(t, err) + + t.Cleanup(func() { + sqlFixture.TearDown(t) + }) + + return store +} diff --git a/db/postgres_fixture.go b/db/postgres_fixture.go new file mode 100644 index 000000000..7a72a8ea1 --- /dev/null +++ b/db/postgres_fixture.go @@ -0,0 +1,141 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "strconv" + "strings" + "testing" + "time" + + _ "github.com/lib/pq" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/stretchr/testify/require" +) + +const ( + testPgUser = "test" + testPgPass = "test" + testPgDBName = "test" + PostgresTag = "15" +) + +// TestPgFixture is a test fixture that starts a Postgres 11 instance in a +// docker container. +type TestPgFixture struct { + db *sql.DB + pool *dockertest.Pool + resource *dockertest.Resource + host string + port int +} + +// NewTestPgFixture constructs a new TestPgFixture starting up a docker +// container running Postgres 11. The started container will expire in after +// the passed duration. +func NewTestPgFixture(t *testing.T, expiry time.Duration, + autoRemove bool) *TestPgFixture { + + // Use a sensible default on Windows (tcp/http) and linux/osx (socket) + // by specifying an empty endpoint. + pool, err := dockertest.NewPool("") + require.NoError(t, err, "Could not connect to docker") + + // Pulls an image, creates a container based on it and runs it. + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: PostgresTag, + Env: []string{ + fmt.Sprintf("POSTGRES_USER=%v", testPgUser), + fmt.Sprintf("POSTGRES_PASSWORD=%v", testPgPass), + fmt.Sprintf("POSTGRES_DB=%v", testPgDBName), + "listen_addresses='*'", + }, + Cmd: []string{ + "postgres", + "-c", "log_statement=all", + "-c", "log_destination=stderr", + }, + }, func(config *docker.HostConfig) { + // Set AutoRemove to true so that stopped container goes away + // by itself, unless we want to keep it around for debugging. + config.AutoRemove = autoRemove + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + require.NoError(t, err, "Could not start resource") + + hostAndPort := resource.GetHostPort("5432/tcp") + parts := strings.Split(hostAndPort, ":") + host := parts[0] + port, err := strconv.ParseInt(parts[1], 10, 64) + require.NoError(t, err) + + fixture := &TestPgFixture{ + host: host, + port: int(port), + } + databaseURL := fixture.GetDSN() + log.Infof("Connecting to Postgres fixture: %v\n", databaseURL) + + // Tell docker to hard kill the container in "expiry" seconds. + require.NoError(t, resource.Expire(uint(expiry.Seconds()))) + + // Exponential backoff-retry, because the application in the container + // might not be ready to accept connections yet. + pool.MaxWait = 120 * time.Second + + var testDB *sql.DB + err = pool.Retry(func() error { + testDB, err = sql.Open("postgres", databaseURL) + if err != nil { + return err + } + return testDB.Ping() + }) + require.NoError(t, err, "Could not connect to docker") + + // Now fill in the rest of the fixture. + fixture.db = testDB + fixture.pool = pool + fixture.resource = resource + + return fixture +} + +// GetDSN returns the DSN (Data Source Name) for the started Postgres node. +func (f *TestPgFixture) GetDSN() string { + return f.GetConfig().DSN(false) +} + +// GetConfig returns the full config of the Postgres node. +func (f *TestPgFixture) GetConfig() *PostgresConfig { + return &PostgresConfig{ + Host: f.host, + Port: f.port, + User: testPgUser, + Password: testPgPass, + DBName: testPgDBName, + RequireSSL: false, + } +} + +// TearDown stops the underlying docker container. +func (f *TestPgFixture) TearDown(t *testing.T) { + err := f.pool.Purge(f.resource) + require.NoError(t, err, "Could not purge resource") +} + +// ClearDB clears the database. +func (f *TestPgFixture) ClearDB(t *testing.T) { + dbConn, err := sql.Open("postgres", f.GetDSN()) + require.NoError(t, err) + + _, err = dbConn.ExecContext( + context.Background(), + `DROP SCHEMA IF EXISTS public CASCADE; + CREATE SCHEMA public;`, + ) + require.NoError(t, err) +} diff --git a/db/schemas.go b/db/schemas.go new file mode 100644 index 000000000..1a7a2096f --- /dev/null +++ b/db/schemas.go @@ -0,0 +1,9 @@ +package db + +import ( + "embed" + _ "embed" +) + +//go:embed sqlc/migrations/*.*.sql +var sqlSchemas embed.FS diff --git a/db/sqlc/accounts.sql.go b/db/sqlc/accounts.sql.go new file mode 100644 index 000000000..4deefdb88 --- /dev/null +++ b/db/sqlc/accounts.sql.go @@ -0,0 +1,394 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: accounts.sql + +package sqlc + +import ( + "context" + "database/sql" + "time" +) + +const addAccountInvoice = `-- name: AddAccountInvoice :exec +INSERT INTO account_invoices (account_id, hash) +VALUES ($1, $2) +` + +type AddAccountInvoiceParams struct { + AccountID int64 + Hash []byte +} + +func (q *Queries) AddAccountInvoice(ctx context.Context, arg AddAccountInvoiceParams) error { + _, err := q.db.ExecContext(ctx, addAccountInvoice, arg.AccountID, arg.Hash) + return err +} + +const deleteAccount = `-- name: DeleteAccount :exec +DELETE FROM accounts +WHERE id = $1 +` + +func (q *Queries) DeleteAccount(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deleteAccount, id) + return err +} + +const deleteAccountPayment = `-- name: DeleteAccountPayment :exec +DELETE FROM account_payments +WHERE hash = $1 +AND account_id = $2 +` + +type DeleteAccountPaymentParams struct { + Hash []byte + AccountID int64 +} + +func (q *Queries) DeleteAccountPayment(ctx context.Context, arg DeleteAccountPaymentParams) error { + _, err := q.db.ExecContext(ctx, deleteAccountPayment, arg.Hash, arg.AccountID) + return err +} + +const getAccount = `-- name: GetAccount :one +SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration +FROM accounts +WHERE id = $1 +` + +func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) { + row := q.db.QueryRowContext(ctx, getAccount, id) + var i Account + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ) + return i, err +} + +const getAccountByLabel = `-- name: GetAccountByLabel :one +SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration +FROM accounts +WHERE label = $1 +` + +func (q *Queries) GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) { + row := q.db.QueryRowContext(ctx, getAccountByLabel, label) + var i Account + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ) + return i, err +} + +const getAccountIDByAlias = `-- name: GetAccountIDByAlias :one +SELECT id +FROM accounts +WHERE alias = $1 +` + +func (q *Queries) GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) { + row := q.db.QueryRowContext(ctx, getAccountIDByAlias, alias) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getAccountIndex = `-- name: GetAccountIndex :one +SELECT value +FROM account_indices +WHERE name = $1 +` + +func (q *Queries) GetAccountIndex(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getAccountIndex, name) + var value int64 + err := row.Scan(&value) + return value, err +} + +const getAccountInvoice = `-- name: GetAccountInvoice :one +SELECT account_id, hash +FROM account_invoices +WHERE account_id = $1 + AND hash = $2 +` + +type GetAccountInvoiceParams struct { + AccountID int64 + Hash []byte +} + +func (q *Queries) GetAccountInvoice(ctx context.Context, arg GetAccountInvoiceParams) (AccountInvoice, error) { + row := q.db.QueryRowContext(ctx, getAccountInvoice, arg.AccountID, arg.Hash) + var i AccountInvoice + err := row.Scan(&i.AccountID, &i.Hash) + return i, err +} + +const getAccountPayment = `-- name: GetAccountPayment :one +SELECT account_id, hash, status, full_amount_msat FROM account_payments +WHERE hash = $1 +AND account_id = $2 +` + +type GetAccountPaymentParams struct { + Hash []byte + AccountID int64 +} + +func (q *Queries) GetAccountPayment(ctx context.Context, arg GetAccountPaymentParams) (AccountPayment, error) { + row := q.db.QueryRowContext(ctx, getAccountPayment, arg.Hash, arg.AccountID) + var i AccountPayment + err := row.Scan( + &i.AccountID, + &i.Hash, + &i.Status, + &i.FullAmountMsat, + ) + return i, err +} + +const insertAccount = `-- name: InsertAccount :one +INSERT INTO accounts (type, initial_balance_msat, current_balance_msat, last_updated, label, alias, expiration) +VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id +` + +type InsertAccountParams struct { + Type int16 + InitialBalanceMsat int64 + CurrentBalanceMsat int64 + LastUpdated time.Time + Label sql.NullString + Alias int64 + Expiration time.Time +} + +func (q *Queries) InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertAccount, + arg.Type, + arg.InitialBalanceMsat, + arg.CurrentBalanceMsat, + arg.LastUpdated, + arg.Label, + arg.Alias, + arg.Expiration, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const listAccountInvoices = `-- name: ListAccountInvoices :many +SELECT account_id, hash +FROM account_invoices +WHERE account_id = $1 +` + +func (q *Queries) ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) { + rows, err := q.db.QueryContext(ctx, listAccountInvoices, accountID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AccountInvoice + for rows.Next() { + var i AccountInvoice + if err := rows.Scan(&i.AccountID, &i.Hash); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAccountPayments = `-- name: ListAccountPayments :many +SELECT account_id, hash, status, full_amount_msat +FROM account_payments +WHERE account_id = $1 +` + +func (q *Queries) ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) { + rows, err := q.db.QueryContext(ctx, listAccountPayments, accountID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AccountPayment + for rows.Next() { + var i AccountPayment + if err := rows.Scan( + &i.AccountID, + &i.Hash, + &i.Status, + &i.FullAmountMsat, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAllAccounts = `-- name: ListAllAccounts :many +SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration +FROM accounts +` + +func (q *Queries) ListAllAccounts(ctx context.Context) ([]Account, error) { + rows, err := q.db.QueryContext(ctx, listAllAccounts) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Account + for rows.Next() { + var i Account + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const setAccountIndex = `-- name: SetAccountIndex :exec +INSERT INTO account_indices (name, value) +VALUES ($1, $2) + ON CONFLICT (name) +DO UPDATE SET value = $2 +` + +type SetAccountIndexParams struct { + Name string + Value int64 +} + +func (q *Queries) SetAccountIndex(ctx context.Context, arg SetAccountIndexParams) error { + _, err := q.db.ExecContext(ctx, setAccountIndex, arg.Name, arg.Value) + return err +} + +const updateAccountBalance = `-- name: UpdateAccountBalance :one +UPDATE accounts +SET current_balance_msat = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountBalanceParams struct { + CurrentBalanceMsat int64 + ID int64 +} + +func (q *Queries) UpdateAccountBalance(ctx context.Context, arg UpdateAccountBalanceParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountBalance, arg.CurrentBalanceMsat, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const updateAccountExpiry = `-- name: UpdateAccountExpiry :one +UPDATE accounts +SET expiration = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountExpiryParams struct { + Expiration time.Time + ID int64 +} + +func (q *Queries) UpdateAccountExpiry(ctx context.Context, arg UpdateAccountExpiryParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountExpiry, arg.Expiration, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const updateAccountLastUpdate = `-- name: UpdateAccountLastUpdate :one +UPDATE accounts +SET last_updated = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountLastUpdateParams struct { + LastUpdated time.Time + ID int64 +} + +func (q *Queries) UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountLastUpdate, arg.LastUpdated, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const upsertAccountPayment = `-- name: UpsertAccountPayment :exec +INSERT INTO account_payments (account_id, hash, status, full_amount_msat) +VALUES ($1, $2, $3, $4) +ON CONFLICT (account_id, hash) +DO UPDATE SET status = $3, full_amount_msat = $4 +` + +type UpsertAccountPaymentParams struct { + AccountID int64 + Hash []byte + Status int16 + FullAmountMsat int64 +} + +func (q *Queries) UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error { + _, err := q.db.ExecContext(ctx, upsertAccountPayment, + arg.AccountID, + arg.Hash, + arg.Status, + arg.FullAmountMsat, + ) + return err +} diff --git a/db/sqlc/db.go b/db/sqlc/db.go new file mode 100644 index 000000000..8ed64d139 --- /dev/null +++ b/db/sqlc/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package sqlc + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/db/sqlc/db_custom.go b/db/sqlc/db_custom.go new file mode 100644 index 000000000..f9e70033b --- /dev/null +++ b/db/sqlc/db_custom.go @@ -0,0 +1,46 @@ +package sqlc + +// BackendType is an enum that represents the type of database backend we're +// using. +type BackendType uint8 + +const ( + // BackendTypeUnknown indicates we're using an unknown backend. + BackendTypeUnknown BackendType = iota + + // BackendTypeSqlite indicates we're using a SQLite backend. + BackendTypeSqlite + + // BackendTypePostgres indicates we're using a Postgres backend. + BackendTypePostgres +) + +// wrappedTX is a wrapper around a DBTX that also stores the database backend +// type. +type wrappedTX struct { + DBTX + + backendType BackendType +} + +// Backend returns the type of database backend we're using. +func (q *Queries) Backend() BackendType { + wtx, ok := q.db.(*wrappedTX) + if !ok { + // Shouldn't happen unless a new database backend type is added + // but not initialized correctly. + return BackendTypeUnknown + } + + return wtx.backendType +} + +// NewSqlite creates a new Queries instance for a SQLite database. +func NewSqlite(db DBTX) *Queries { + return &Queries{db: &wrappedTX{db, BackendTypeSqlite}} +} + +// NewPostgres creates a new Queries instance for a Postgres database. +func NewPostgres(db DBTX) *Queries { + return &Queries{db: &wrappedTX{db, BackendTypePostgres}} +} diff --git a/db/sqlc/migrations/000001_accounts.down.sql b/db/sqlc/migrations/000001_accounts.down.sql new file mode 100644 index 000000000..da81c5dfe --- /dev/null +++ b/db/sqlc/migrations/000001_accounts.down.sql @@ -0,0 +1,4 @@ +DROP TABLE IF EXISTS account_payments; +DROP TABLE IF EXISTS account_invoices; +DROP TABLE IF EXISTS account_indices; +DROP TABLE IF EXISTS accounts; diff --git a/db/sqlc/migrations/000001_accounts.up.sql b/db/sqlc/migrations/000001_accounts.up.sql new file mode 100644 index 000000000..f9e17e168 --- /dev/null +++ b/db/sqlc/migrations/000001_accounts.up.sql @@ -0,0 +1,71 @@ +CREATE TABLE IF NOT EXISTS accounts ( + -- The auto incrementing primary key. + id INTEGER PRIMARY KEY, + + -- The ID that was used to identify the account in the legacy KVDB store. + -- In order to avoid breaking the API, we keep this field here so that + -- we can still look up accounts by this ID for the time being. + alias BIGINT NOT NULL UNIQUE, + + -- An optional label to use for the account. If it is set, it must be + -- unique. + label TEXT UNIQUE, + + -- The account type. + type SMALLINT NOT NULL, + + -- The accounts initial balance. This is never updated. + initial_balance_msat BIGINT NOT NULL, + + -- The accounts current balance. This is updated as the account is used. + current_balance_msat BIGINT NOT NULL, + + -- The last time the account was updated. + last_updated TIMESTAMP NOT NULL, + + -- The time that the account will expire. + expiration TIMESTAMP NOT NULL +); + +-- The account_payments table stores all the payment hashes of outgoing +-- payments that are associated with a particular account. These are used to +-- when an account should be debited. +CREATE TABLE IF NOT EXISTS account_payments ( + -- The account that this payment is linked to. + account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE, + + -- The payment hash of the payment. + hash BLOB NOT NULL, + + -- The LND RPC status of the payment. + status SMALLINT NOT NULL, + + -- The total amount of the payment in millisatoshis. + -- This includes the payment amount and estimated routing fee. + full_amount_msat BIGINT NOT NULL, + + UNIQUE(account_id, hash) +); + +-- The account_invoices table stores all the invoice payment hashes that +-- are associated with a particular account. These are used to determine +-- when an account should be credited. +CREATE TABLE IF NOT EXISTS account_invoices ( + -- The account that this invoice is linked to. + account_id BIGINT NOT NULL REFERENCES accounts(id) ON DELETE CASCADE, + + -- The payment hash of the invoice. + hash BLOB NOT NULL, + + UNIQUE(account_id, hash) +); + +-- The account_indices table stores any string-to-integer mappings that are +-- used by the accounts system. +CREATE TABLE IF NOT EXISTS account_indices ( + -- The unique name of the index. + name TEXT NOT NULL UNIQUE, + + -- The current value of the index. + value BIGINT NOT NULL +); \ No newline at end of file diff --git a/db/sqlc/models.go b/db/sqlc/models.go new file mode 100644 index 000000000..f11698e87 --- /dev/null +++ b/db/sqlc/models.go @@ -0,0 +1,38 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package sqlc + +import ( + "database/sql" + "time" +) + +type Account struct { + ID int64 + Alias int64 + Label sql.NullString + Type int16 + InitialBalanceMsat int64 + CurrentBalanceMsat int64 + LastUpdated time.Time + Expiration time.Time +} + +type AccountIndex struct { + Name string + Value int64 +} + +type AccountInvoice struct { + AccountID int64 + Hash []byte +} + +type AccountPayment struct { + AccountID int64 + Hash []byte + Status int16 + FullAmountMsat int64 +} diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go new file mode 100644 index 000000000..b0265c596 --- /dev/null +++ b/db/sqlc/querier.go @@ -0,0 +1,33 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package sqlc + +import ( + "context" + "database/sql" +) + +type Querier interface { + AddAccountInvoice(ctx context.Context, arg AddAccountInvoiceParams) error + DeleteAccount(ctx context.Context, id int64) error + DeleteAccountPayment(ctx context.Context, arg DeleteAccountPaymentParams) error + GetAccount(ctx context.Context, id int64) (Account, error) + GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) + GetAccountIndex(ctx context.Context, name string) (int64, error) + GetAccountInvoice(ctx context.Context, arg GetAccountInvoiceParams) (AccountInvoice, error) + GetAccountPayment(ctx context.Context, arg GetAccountPaymentParams) (AccountPayment, error) + InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) + ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) + ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) + ListAllAccounts(ctx context.Context) ([]Account, error) + SetAccountIndex(ctx context.Context, arg SetAccountIndexParams) error + UpdateAccountBalance(ctx context.Context, arg UpdateAccountBalanceParams) (int64, error) + UpdateAccountExpiry(ctx context.Context, arg UpdateAccountExpiryParams) (int64, error) + UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) + UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error +} + +var _ Querier = (*Queries)(nil) diff --git a/db/sqlc/queries/accounts.sql b/db/sqlc/queries/accounts.sql new file mode 100644 index 000000000..637a49727 --- /dev/null +++ b/db/sqlc/queries/accounts.sql @@ -0,0 +1,92 @@ +-- name: InsertAccount :one +INSERT INTO accounts (type, initial_balance_msat, current_balance_msat, last_updated, label, alias, expiration) +VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id; + +-- name: UpdateAccountBalance :one +UPDATE accounts +SET current_balance_msat = $1 +WHERE id = $2 +RETURNING id; + +-- name: UpdateAccountExpiry :one +UPDATE accounts +SET expiration = $1 +WHERE id = $2 +RETURNING id; + +-- name: UpdateAccountLastUpdate :one +UPDATE accounts +SET last_updated = $1 +WHERE id = $2 +RETURNING id; + +-- name: AddAccountInvoice :exec +INSERT INTO account_invoices (account_id, hash) +VALUES ($1, $2); + +-- name: DeleteAccountPayment :exec +DELETE FROM account_payments +WHERE hash = $1 +AND account_id = $2; + +-- name: UpsertAccountPayment :exec +INSERT INTO account_payments (account_id, hash, status, full_amount_msat) +VALUES ($1, $2, $3, $4) +ON CONFLICT (account_id, hash) +DO UPDATE SET status = $3, full_amount_msat = $4; + +-- name: GetAccountPayment :one +SELECT * FROM account_payments +WHERE hash = $1 +AND account_id = $2; + +-- name: GetAccount :one +SELECT * +FROM accounts +WHERE id = $1; + +-- name: GetAccountIDByAlias :one +SELECT id +FROM accounts +WHERE alias = $1; + +-- name: GetAccountByLabel :one +SELECT * +FROM accounts +WHERE label = $1; + +-- name: DeleteAccount :exec +DELETE FROM accounts +WHERE id = $1; + +-- name: ListAllAccounts :many +SELECT * +FROM accounts; + +-- name: ListAccountPayments :many +SELECT * +FROM account_payments +WHERE account_id = $1; + +-- name: ListAccountInvoices :many +SELECT * +FROM account_invoices +WHERE account_id = $1; + +-- name: GetAccountInvoice :one +SELECT * +FROM account_invoices +WHERE account_id = $1 + AND hash = $2; + +-- name: SetAccountIndex :exec +INSERT INTO account_indices (name, value) +VALUES ($1, $2) + ON CONFLICT (name) +DO UPDATE SET value = $2; + +-- name: GetAccountIndex :one +SELECT value +FROM account_indices +WHERE name = $1; diff --git a/db/sqlerrors.go b/db/sqlerrors.go new file mode 100644 index 000000000..1116bbe9b --- /dev/null +++ b/db/sqlerrors.go @@ -0,0 +1,198 @@ +package db + +import ( + "errors" + "fmt" + "strings" + + "github.com/jackc/pgconn" + "github.com/jackc/pgerrcode" + "modernc.org/sqlite" + sqlite3 "modernc.org/sqlite/lib" +) + +var ( + // ErrRetriesExceeded is returned when a transaction is retried more + // than the max allowed valued without a success. + ErrRetriesExceeded = errors.New("db tx retries exceeded") +) + +// MapSQLError attempts to interpret a given error as a database agnostic SQL +// error. +func MapSQLError(err error) error { + // Attempt to interpret the error as a sqlite error. + var sqliteErr *sqlite.Error + if errors.As(err, &sqliteErr) { + return parseSqliteError(sqliteErr) + } + + // Attempt to interpret the error as a postgres error. + var pqErr *pgconn.PgError + if errors.As(err, &pqErr) { + return parsePostgresError(pqErr) + } + + // Return original error if it could not be classified as a database + // specific error. + return err +} + +// parseSqliteError attempts to parse a sqlite error as a database agnostic +// SQL error. +func parseSqliteError(sqliteErr *sqlite.Error) error { + switch sqliteErr.Code() { + // Handle unique constraint violation error. + case sqlite3.SQLITE_CONSTRAINT_UNIQUE: + return &ErrSqlUniqueConstraintViolation{ + DbError: sqliteErr, + } + + // Database is currently busy, so we'll need to try again. + case sqlite3.SQLITE_BUSY: + return &ErrSerializationError{ + DbError: sqliteErr, + } + + // A write operation could not continue because of a conflict within the + // same database connection. + case sqlite3.SQLITE_LOCKED: + return &ErrDeadlockError{ + DbError: sqliteErr, + } + + // Generic error, need to parse the message further. + case sqlite3.SQLITE_ERROR: + errMsg := sqliteErr.Error() + + switch { + case strings.Contains(errMsg, "no such table"): + return &ErrSchemaError{ + DbError: sqliteErr, + } + + default: + return fmt.Errorf("unknown sqlite error: %w", sqliteErr) + } + + default: + return fmt.Errorf("unknown sqlite error: %w", sqliteErr) + } +} + +// parsePostgresError attempts to parse a postgres error as a database agnostic +// SQL error. +func parsePostgresError(pqErr *pgconn.PgError) error { + switch pqErr.Code { + // Handle unique constraint violation error. + case pgerrcode.UniqueViolation: + return &ErrSqlUniqueConstraintViolation{ + DbError: pqErr, + } + + // Unable to serialize the transaction, so we'll need to try again. + case pgerrcode.SerializationFailure: + return &ErrSerializationError{ + DbError: pqErr, + } + + // A write operation could not continue because of a conflict within the + // same database connection. + case pgerrcode.DeadlockDetected: + return &ErrDeadlockError{ + DbError: pqErr, + } + + // Handle schema error. + case pgerrcode.UndefinedColumn, pgerrcode.UndefinedTable: + return &ErrSchemaError{ + DbError: pqErr, + } + + default: + return fmt.Errorf("unknown postgres error: %w", pqErr) + } +} + +// ErrSqlUniqueConstraintViolation is an error type which represents a database +// agnostic SQL unique constraint violation. +type ErrSqlUniqueConstraintViolation struct { + DbError error +} + +func (e ErrSqlUniqueConstraintViolation) Error() string { + return fmt.Sprintf("sql unique constraint violation: %v", e.DbError) +} + +// ErrSerializationError is an error type which represents a database agnostic +// error that a transaction couldn't be serialized with other concurrent db +// transactions. +type ErrSerializationError struct { + DbError error +} + +// Unwrap returns the wrapped error. +func (e ErrSerializationError) Unwrap() error { + return e.DbError +} + +// Error returns the error message. +func (e ErrSerializationError) Error() string { + return e.DbError.Error() +} + +// ErrDeadlockError is an error type which represents a database agnostic +// error where transactions have led to cyclic dependencies in lock acquisition. +type ErrDeadlockError struct { + DbError error +} + +// Unwrap returns the wrapped error. +func (e ErrDeadlockError) Unwrap() error { + return e.DbError +} + +// Error returns the error message. +func (e ErrDeadlockError) Error() string { + return e.DbError.Error() +} + +// IsSerializationError returns true if the given error is a serialization +// error. +func IsSerializationError(err error) bool { + var serializationError *ErrSerializationError + return errors.As(err, &serializationError) +} + +// IsDeadlockError returns true if the given error is a deadlock error. +func IsDeadlockError(err error) bool { + var deadlockError *ErrDeadlockError + return errors.As(err, &deadlockError) +} + +// IsSerializationOrDeadlockError returns true if the given error is either a +// deadlock error or a serialization error. +func IsSerializationOrDeadlockError(err error) bool { + return IsDeadlockError(err) || IsSerializationError(err) +} + +// ErrSchemaError is an error type which represents a database agnostic error +// that the schema of the database is incorrect for the given query. +type ErrSchemaError struct { + DbError error +} + +// Unwrap returns the wrapped error. +func (e ErrSchemaError) Unwrap() error { + return e.DbError +} + +// Error returns the error message. +func (e ErrSchemaError) Error() string { + return e.DbError.Error() +} + +// IsSchemaError returns true if the given error is a schema error. +func IsSchemaError(err error) bool { + var schemaError *ErrSchemaError + return errors.As(err, &schemaError) +} diff --git a/db/sqlite.go b/db/sqlite.go new file mode 100644 index 000000000..803362fa8 --- /dev/null +++ b/db/sqlite.go @@ -0,0 +1,308 @@ +package db + +import ( + "database/sql" + "fmt" + "net/url" + "path/filepath" + "testing" + "time" + + "github.com/golang-migrate/migrate/v4" + sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/stretchr/testify/require" + _ "modernc.org/sqlite" // Register relevant drivers. +) + +const ( + // sqliteOptionPrefix is the string prefix sqlite uses to set various + // options. This is used in the following format: + // * sqliteOptionPrefix || option_name = option_value. + sqliteOptionPrefix = "_pragma" + + // sqliteTxLockImmediate is a dsn option used to ensure that write + // transactions are started immediately. + sqliteTxLockImmediate = "_txlock=immediate" + + // defaultMaxConns is the number of permitted active and idle + // connections. We want to limit this so it isn't unlimited. We use the + // same value for the number of idle connections as, this can speed up + // queries given a new connection doesn't need to be established each + // time. + defaultMaxConns = 25 + + // defaultConnMaxLifetime is the maximum amount of time a connection can + // be reused for before it is closed. + defaultConnMaxLifetime = 10 * time.Minute +) + +var ( + // sqliteSchemaReplacements is a map of schema strings that need to be + // replaced for sqlite. There currently aren't any replacements, because + // the SQL files are written with SQLite compatibility in mind. + sqliteSchemaReplacements = map[string]string{} +) + +// SqliteConfig holds all the config arguments needed to interact with our +// sqlite DB. +// +// nolint: lll +type SqliteConfig struct { + // SkipMigrations if true, then all the tables will be created on start + // up if they don't already exist. + SkipMigrations bool `long:"skipmigrations" description:"Skip applying migrations on startup."` + + // SkipMigrationDbBackup if true, then a backup of the database will not + // be created before applying migrations. + SkipMigrationDbBackup bool `long:"skipmigrationdbbackup" description:"Skip creating a backup of the database before applying migrations."` + + // DatabaseFileName is the full file path where the database file can be + // found. + DatabaseFileName string `long:"dbfile" description:"The full path to the database."` +} + +// SqliteStore is a sqlite3 based database for the Taproot Asset daemon. +type SqliteStore struct { + cfg *SqliteConfig + + *BaseDB +} + +// NewSqliteStore attempts to open a new sqlite database based on the passed +// config. +func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { + // The set of pragma options are accepted using query options. For now + // we only want to ensure that foreign key constraints are properly + // enforced. + pragmaOptions := []struct { + name string + value string + }{ + { + name: "foreign_keys", + value: "on", + }, + { + name: "journal_mode", + value: "WAL", + }, + { + name: "busy_timeout", + value: "5000", + }, + { + // With the WAL mode, this ensures that we also do an + // extra WAL sync after each transaction. The normal + // sync mode skips this and gives better performance, + // but risks durability. + name: "synchronous", + value: "full", + }, + { + // This is used to ensure proper durability for users + // running on Mac OS. It uses the correct fsync system + // call to ensure items are fully flushed to disk. + name: "fullfsync", + value: "true", + }, + } + sqliteOptions := make(url.Values) + for _, option := range pragmaOptions { + sqliteOptions.Add( + sqliteOptionPrefix, + fmt.Sprintf("%v=%v", option.name, option.value), + ) + } + + // Construct the DSN which is just the database file name, appended + // with the series of pragma options as a query URL string. For more + // details on the formatting here, see the modernc.org/sqlite docs: + // https://pkg.go.dev/modernc.org/sqlite#Driver.Open. + dsn := fmt.Sprintf( + "%v?%v&%v", cfg.DatabaseFileName, sqliteOptions.Encode(), + sqliteTxLockImmediate, + ) + db, err := sql.Open("sqlite", dsn) + if err != nil { + return nil, err + } + + db.SetMaxOpenConns(defaultMaxConns) + db.SetMaxIdleConns(defaultMaxConns) + db.SetConnMaxLifetime(defaultConnMaxLifetime) + + queries := sqlc.NewSqlite(db) + s := &SqliteStore{ + cfg: cfg, + BaseDB: &BaseDB{ + DB: db, + Queries: queries, + }, + } + + // Now that the database is open, populate the database with our set of + // schemas based on our embedded in-memory file system. + if !cfg.SkipMigrations { + if err := s.ExecuteMigrations(s.backupAndMigrate); err != nil { + return nil, fmt.Errorf("error executing migrations: "+ + "%w", err) + } + } + + return s, nil +} + +// backupSqliteDatabase creates a backup of the given SQLite database. +func backupSqliteDatabase(srcDB *sql.DB, dbFullFilePath string) error { + if srcDB == nil { + return fmt.Errorf("backup source database is nil") + } + + // Create a database backup file full path from the given source + // database full file path. + // + // Get the current time and format it as a Unix timestamp in + // nanoseconds. + timestamp := time.Now().UnixNano() + + // Add the timestamp to the backup name. + backupFullFilePath := fmt.Sprintf( + "%s.%d.backup", dbFullFilePath, timestamp, + ) + + log.Infof("Creating backup of database file: %v -> %v", + dbFullFilePath, backupFullFilePath) + + // Create the database backup. + vacuumIntoQuery := "VACUUM INTO ?;" + stmt, err := srcDB.Prepare(vacuumIntoQuery) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(backupFullFilePath) + if err != nil { + return err + } + + return nil +} + +// backupAndMigrate is a helper function that creates a database backup before +// initiating the migration, and then migrates the database to the latest +// version. +func (s *SqliteStore) backupAndMigrate(mig *migrate.Migrate, + currentDbVersion int, maxMigrationVersion uint) error { + + // Determine if a database migration is necessary given the current + // database version and the maximum migration version. + versionUpgradePending := currentDbVersion < int(maxMigrationVersion) + if !versionUpgradePending { + log.Infof("Current database version is up-to-date, skipping "+ + "migration attempt and backup creation "+ + "(current_db_version=%v, max_migration_version=%v)", + currentDbVersion, maxMigrationVersion) + return nil + } + + // At this point, we know that a database migration is necessary. + // Create a backup of the database before starting the migration. + if !s.cfg.SkipMigrationDbBackup { + log.Infof("Creating database backup (before applying " + + "migration(s))") + + err := backupSqliteDatabase(s.DB, s.cfg.DatabaseFileName) + if err != nil { + return err + } + } else { + log.Infof("Skipping database backup creation before applying " + + "migration(s)") + } + + log.Infof("Applying migrations to database") + return mig.Up() +} + +// ExecuteMigrations runs migrations for the sqlite database, depending on the +// target given, either all migrations or up to a given version. +func (s *SqliteStore) ExecuteMigrations(target MigrationTarget, + optFuncs ...MigrateOpt) error { + + opts := defaultMigrateOptions() + for _, optFunc := range optFuncs { + optFunc(opts) + } + + driver, err := sqlite_migrate.WithInstance( + s.DB, &sqlite_migrate.Config{}, + ) + if err != nil { + return fmt.Errorf("error creating sqlite migration: %w", err) + } + + sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements) + return applyMigrations( + sqliteFS, driver, "sqlc/migrations", "sqlite", target, opts, + ) +} + +// NewTestSqliteDB is a helper function that creates an SQLite database for +// testing. +func NewTestSqliteDB(t *testing.T) *SqliteStore { + t.Helper() + + // TODO(roasbeef): if we pass :memory: for the file name, then we get + // an in mem version to speed up tests + dbPath := filepath.Join(t.TempDir(), "tmp.db") + t.Logf("Creating new SQLite DB handle for testing: %s", dbPath) + + return NewTestSqliteDbHandleFromPath(t, dbPath) +} + +// NewTestSqliteDbHandleFromPath is a helper function that creates a SQLite +// database handle given a database file path. +func NewTestSqliteDbHandleFromPath(t *testing.T, dbPath string) *SqliteStore { + t.Helper() + + sqlDB, err := NewSqliteStore(&SqliteConfig{ + DatabaseFileName: dbPath, + SkipMigrations: false, + }) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, sqlDB.DB.Close()) + }) + + return sqlDB +} + +// NewTestSqliteDBWithVersion is a helper function that creates an SQLite +// database for testing and migrates it to the given version. +func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore { + t.Helper() + + t.Logf("Creating new SQLite DB for testing, migrating to version %d", + version) + + // TODO(roasbeef): if we pass :memory: for the file name, then we get + // an in mem version to speed up tests + dbFileName := filepath.Join(t.TempDir(), "tmp.db") + sqlDB, err := NewSqliteStore(&SqliteConfig{ + DatabaseFileName: dbFileName, + SkipMigrations: true, + }) + require.NoError(t, err) + + err = sqlDB.ExecuteMigrations(TargetVersion(version)) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, sqlDB.DB.Close()) + }) + + return sqlDB +} diff --git a/go.mod b/go.mod index 4f7cf4894..362e02e3d 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,13 @@ require ( github.com/btcsuite/btcwallet/walletdb v1.4.4 github.com/davecgh/go-spew v1.1.1 github.com/go-errors/errors v1.0.1 + github.com/golang-migrate/migrate/v4 v4.17.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 github.com/improbable-eng/grpc-web v0.12.0 + github.com/jackc/pgconn v1.14.3 + github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 github.com/jessevdk/go-flags v1.4.0 + github.com/lib/pq v1.10.9 github.com/lightninglabs/faraday v0.2.14-alpha github.com/lightninglabs/faraday/frdrpc v1.0.0 github.com/lightninglabs/lightning-node-connect v0.3.2-alpha.0.20240822142323-ee4e7ff52f83 @@ -34,6 +38,7 @@ require ( github.com/lightningnetwork/lnd/tor v1.1.2 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f github.com/mwitkow/grpc-proxy v0.0.0-20230212185441-f345521cb9c9 + github.com/ory/dockertest/v3 v3.10.0 github.com/stretchr/testify v1.9.0 github.com/urfave/cli v1.22.9 go.etcd.io/bbolt v1.3.11 @@ -45,6 +50,7 @@ require ( google.golang.org/protobuf v1.34.2 gopkg.in/macaroon-bakery.v2 v2.1.0 gopkg.in/macaroon.v2 v2.1.0 + modernc.org/sqlite v1.30.0 ) require ( @@ -94,7 +100,6 @@ require ( github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.1 // indirect - github.com/golang-migrate/migrate/v4 v4.17.0 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/btree v1.0.1 // indirect @@ -110,8 +115,6 @@ require ( github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect - github.com/jackc/pgconn v1.14.3 // indirect - github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgproto3/v2 v2.3.3 // indirect @@ -131,7 +134,6 @@ require ( github.com/kkdai/bstream v1.0.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect - github.com/lib/pq v1.10.9 // indirect github.com/libdns/libdns v0.2.1 // indirect github.com/lightninglabs/aperture v0.3.4-beta // indirect github.com/lightninglabs/gozmq v0.0.0-20191113021534-d20a764486bf // indirect @@ -157,7 +159,6 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.2 // indirect github.com/opencontainers/runc v1.1.14 // indirect - github.com/ory/dockertest/v3 v3.10.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.14.0 // indirect @@ -220,7 +221,6 @@ require ( modernc.org/libc v1.50.9 // indirect modernc.org/mathutil v1.6.0 // indirect modernc.org/memory v1.8.0 // indirect - modernc.org/sqlite v1.30.0 // indirect modernc.org/strutil v1.2.0 // indirect modernc.org/token v1.1.0 // indirect nhooyr.io/websocket v1.8.7 // indirect diff --git a/log.go b/log.go index 3b594e395..d02b034df 100644 --- a/log.go +++ b/log.go @@ -7,6 +7,7 @@ import ( "github.com/lightninglabs/lightning-node-connect/mailbox" "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/autopilotserver" + "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/firewall" "github.com/lightninglabs/lightning-terminal/firewalldb" mid "github.com/lightninglabs/lightning-terminal/rpcmiddleware" @@ -89,6 +90,7 @@ func SetupLoggers(root *build.RotatingLogWriter, intercept signal.Interceptor) { lnd.AddSubLogger( root, subservers.Subsystem, intercept, subservers.UseLogger, ) + lnd.AddSubLogger(root, db.Subsystem, intercept, db.UseLogger) // Add daemon loggers to lnd's root logger. faraday.SetupLoggers(root, intercept) diff --git a/scripts/gen_sqlc_docker.sh b/scripts/gen_sqlc_docker.sh new file mode 100755 index 000000000..16db97f2c --- /dev/null +++ b/scripts/gen_sqlc_docker.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +set -e + +# restore_files is a function to restore original schema files. +restore_files() { + echo "Restoring SQLite bigint patch..." + for file in db/sqlc/migrations/*.up.sql.bak; do + mv "$file" "${file%.bak}" + done +} + +# Set trap to call restore_files on script exit. This makes sure the old files +# are always restored. +trap restore_files EXIT + +# Directory of the script file, independent of where it's called from. +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +# Use the user's cache directories +GOCACHE=$(go env GOCACHE) +GOMODCACHE=$(go env GOMODCACHE) + +# SQLite doesn't support "BIGINT PRIMARY KEY" for auto-incrementing primary +# keys, only "INTEGER PRIMARY KEY". Internally it uses 64-bit integers for +# numbers anyway, independent of the column type. So we can just use +# "INTEGER PRIMARY KEY" and it will work the same under the hood, giving us +# auto incrementing 64-bit integers. +# _BUT_, sqlc will generate Go code with int32 if we use "INTEGER PRIMARY KEY", +# even though we want int64. So before we run sqlc, we need to patch the +# source schema SQL files to use "BIGINT PRIMARY KEY" instead of "INTEGER +# PRIMARY KEY". +echo "Applying SQLite bigint patch..." +for file in db/sqlc/migrations/*.up.sql; do + echo "Patching $file" + sed -i.bak -E 's/INTEGER PRIMARY KEY/BIGINT PRIMARY KEY/g' "$file" +done + +echo "Generating sql models and queries in go..." + +# Run the script to generate the new generated code. Once the script exits, we +# use `trap` to make sure all files are restored. +docker run \ + --rm \ + --user "$UID:$(id -g)" \ + -e UID=$UID \ + -v "$DIR/../:/build" \ + -w /build \ + sqlc/sqlc:1.25.0 generate diff --git a/sqlc.yaml b/sqlc.yaml new file mode 100644 index 000000000..a6d8f2c54 --- /dev/null +++ b/sqlc.yaml @@ -0,0 +1,10 @@ +version: "2" +sql: + - engine: "postgresql" + schema: "db/sqlc/migrations" + queries: "db/sqlc/queries" + gen: + go: + out: db/sqlc + package: sqlc + emit_interface: true