Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 31 additions & 36 deletions backend/server/internal/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type DB struct {
*gorm.DB
}

func OpenSQLite(dsn string, config *gorm.Config) (*DB, error) {
func OpenSQLite(dsn string, config *gorm.Config) (dbPtr *DB, err error) {
db, err := gorm.Open(sqlite.Open(dsn), config)
if err != nil {
return nil, fmt.Errorf("gorm.Open: %w", err)
Expand All @@ -33,7 +33,7 @@ func OpenSQLite(dsn string, config *gorm.Config) (*DB, error) {
return &DB{db}, nil
}

func OpenPostgres(dsn string, config *gorm.Config) (*DB, error) {
func OpenPostgres(dsn string, config *gorm.Config) (dbPtr *DB, err error) {
sqltrace.Register("pgx", &stdlib.Driver{}, sqltrace.WithServiceName("hishtory-api"))
sqlDb, err := sqltrace.Open("pgx", dsn)
if err != nil {
Expand All @@ -47,7 +47,7 @@ func OpenPostgres(dsn string, config *gorm.Config) (*DB, error) {
return &DB{db}, nil
}

func (db *DB) AddDatabaseTables() error {
func (db *DB) AddDatabaseTables() (err error) {
models := []any{
&shared.EncHistoryEntry{},
&Device{},
Expand All @@ -67,7 +67,7 @@ func (db *DB) AddDatabaseTables() error {
return nil
}

func (db *DB) CreateIndices() error {
func (db *DB) CreateIndices() (err error) {
// Note: If adding a new index here, consider manually running it on the prod DB using CONCURRENTLY to
// make server startup non-blocking. The benefit of this function is primarily for other people so they
// don't have to manually create these indexes.
Expand Down Expand Up @@ -99,7 +99,7 @@ func (db *DB) CreateIndices() error {
return nil
}

func (db *DB) Close() error {
func (db *DB) Close() (err error) {
rawDB, err := db.DB.DB()
if err != nil {
return fmt.Errorf("db.DB.DB: %w", err)
Expand All @@ -112,7 +112,7 @@ func (db *DB) Close() error {
return nil
}

func (db *DB) Ping() error {
func (db *DB) Ping() (err error) {
rawDB, err := db.DB.DB()
if err != nil {
return fmt.Errorf("db.DB.DB: %w", err)
Expand All @@ -125,7 +125,7 @@ func (db *DB) Ping() error {
return nil
}

func (db *DB) SetMaxIdleConns(n int) error {
func (db *DB) SetMaxIdleConns(n int) (err error) {
rawDB, err := db.DB.DB()
if err != nil {
return err
Expand All @@ -136,7 +136,7 @@ func (db *DB) SetMaxIdleConns(n int) error {
return nil
}

func (db *DB) Stats() (sql.DBStats, error) {
func (db *DB) Stats() (stats sql.DBStats, err error) {
rawDB, err := db.DB.DB()
if err != nil {
return sql.DBStats{}, fmt.Errorf("db.DB.DB: %w", err)
Expand All @@ -145,18 +145,17 @@ func (db *DB) Stats() (sql.DBStats, error) {
return rawDB.Stats(), nil
}

func (db *DB) DistinctUsers(ctx context.Context) (int64, error) {
func (db *DB) DistinctUsers(ctx context.Context) (num int64, err error) {
row := db.WithContext(ctx).Raw("SELECT COUNT(DISTINCT devices.user_id) FROM devices").Row()
var numDistinctUsers int64
err := row.Scan(&numDistinctUsers)
err = row.Scan(&num)
if err != nil {
return 0, fmt.Errorf("row.Scan: %w", err)
}

return numDistinctUsers, nil
return num, nil
}

func (db *DB) UserAlreadyExist(ctx context.Context, userID string) (bool, error) {
func (db *DB) UserAlreadyExist(ctx context.Context, userID string) (exists bool, err error) {
var cnt int64
tx := db.WithContext(ctx).Table("devices").Where("user_id = ?", userID).Count(&cnt)
if tx.Error != nil {
Expand All @@ -169,7 +168,7 @@ func (db *DB) UserAlreadyExist(ctx context.Context, userID string) (bool, error)
return false, nil
}

func (db *DB) DumpRequestCreate(ctx context.Context, req *shared.DumpRequest) error {
func (db *DB) DumpRequestCreate(ctx context.Context, req *shared.DumpRequest) (err error) {
tx := db.WithContext(ctx).Create(req)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
Expand All @@ -178,8 +177,7 @@ func (db *DB) DumpRequestCreate(ctx context.Context, req *shared.DumpRequest) er
return nil
}

func (db *DB) DumpRequestForUserAndDevice(ctx context.Context, userID, deviceID string) ([]*shared.DumpRequest, error) {
var dumpRequests []*shared.DumpRequest
func (db *DB) DumpRequestForUserAndDevice(ctx context.Context, userID, deviceID string) (dumpRequests []*shared.DumpRequest, err error) {
// Filter out ones requested by the hishtory instance that sent this request
tx := db.WithContext(ctx).Where("user_id = ? AND requesting_device_id != ?", userID, deviceID).Find(&dumpRequests)
if tx.Error != nil {
Expand All @@ -189,7 +187,7 @@ func (db *DB) DumpRequestForUserAndDevice(ctx context.Context, userID, deviceID
return dumpRequests, nil
}

func (db *DB) DumpRequestDeleteForUserAndDevice(ctx context.Context, userID, deviceID string) error {
func (db *DB) DumpRequestDeleteForUserAndDevice(ctx context.Context, userID, deviceID string) (err error) {
tx := db.WithContext(ctx).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userID, deviceID)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
Expand All @@ -198,7 +196,7 @@ func (db *DB) DumpRequestDeleteForUserAndDevice(ctx context.Context, userID, dev
return nil
}

func (db *DB) ApplyDeletionRequestsToBackend(ctx context.Context, requests []*shared.DeletionRequest) (int64, error) {
func (db *DB) ApplyDeletionRequestsToBackend(ctx context.Context, requests []*shared.DeletionRequest) (rowsAffected int64, err error) {
if len(requests) == 0 {
return 0, nil
}
Expand All @@ -213,7 +211,7 @@ func (db *DB) ApplyDeletionRequestsToBackend(ctx context.Context, requests []*sh
return db.DeleteMessagesFromBackend(ctx, userId, deletedMessages)
}

func (db *DB) UninstallDevice(ctx context.Context, userId, deviceId string) (int64, error) {
func (db *DB) UninstallDevice(ctx context.Context, userId, deviceId string) (rowsAffected int64, err error) {
// Note that this is deleting entries that are destined to be *read* by this device. If there are other devices on this account,
// those queues are unaffected.
r1 := db.WithContext(ctx).Where("user_id = ? AND device_id = ?", userId, deviceId).Delete(&shared.EncHistoryEntry{})
Expand All @@ -239,7 +237,7 @@ func (db *DB) UninstallDevice(ctx context.Context, userId, deviceId string) (int
return r1.RowsAffected + r2.RowsAffected + r3.RowsAffected, nil
}

func (db *DB) DeleteMessagesFromBackend(ctx context.Context, userId string, deletedMessages []shared.MessageIdentifier) (int64, error) {
func (db *DB) DeleteMessagesFromBackend(ctx context.Context, userId string, deletedMessages []shared.MessageIdentifier) (rowsAffected int64, err error) {
if len(deletedMessages) == 0 {
return 0, nil
}
Expand All @@ -248,7 +246,6 @@ func (db *DB) DeleteMessagesFromBackend(ctx context.Context, userId string, dele
return 0, fmt.Errorf("failed to delete entry because userId is empty")
}

var rowsAffected int64
for _, chunkOfMessages := range lo.Chunk(deletedMessages, 255) {
tx := db.WithContext(ctx).Where("false")
for _, message := range chunkOfMessages {
Expand Down Expand Up @@ -278,7 +275,7 @@ func (db *DB) DeleteMessagesFromBackend(ctx context.Context, userId string, dele
return rowsAffected, nil
}

func (db *DB) DeletionRequestInc(ctx context.Context, userID, deviceID string) error {
func (db *DB) DeletionRequestInc(ctx context.Context, userID, deviceID string) (err error) {
tx := db.WithContext(ctx).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE user_id = ? AND destination_device_id = ?", userID, deviceID)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
Expand All @@ -287,8 +284,7 @@ func (db *DB) DeletionRequestInc(ctx context.Context, userID, deviceID string) e
return nil
}

func (db *DB) DeletionRequestsForUserAndDevice(ctx context.Context, userID, deviceID string) ([]*shared.DeletionRequest, error) {
var deletionRequests []*shared.DeletionRequest
func (db *DB) DeletionRequestsForUserAndDevice(ctx context.Context, userID, deviceID string) (deletionRequests []*shared.DeletionRequest, err error) {
tx := db.WithContext(ctx).Where("user_id = ? AND destination_device_id = ?", userID, deviceID).Find(&deletionRequests)
if tx.Error != nil {
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
Expand All @@ -297,7 +293,7 @@ func (db *DB) DeletionRequestsForUserAndDevice(ctx context.Context, userID, devi
return deletionRequests, nil
}

func (db *DB) DeletionRequestCreate(ctx context.Context, request *shared.DeletionRequest) error {
func (db *DB) DeletionRequestCreate(ctx context.Context, request *shared.DeletionRequest) (err error) {
userID := request.UserId

devices, err := db.DevicesForUser(ctx, userID)
Expand Down Expand Up @@ -333,7 +329,7 @@ func (db *DB) DeletionRequestCreate(ctx context.Context, request *shared.Deletio
return nil
}

func (db *DB) FeedbackCreate(ctx context.Context, feedback *shared.Feedback) error {
func (db *DB) FeedbackCreate(ctx context.Context, feedback *shared.Feedback) (err error) {
tx := db.WithContext(ctx).Create(feedback)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
Expand All @@ -342,7 +338,7 @@ func (db *DB) FeedbackCreate(ctx context.Context, feedback *shared.Feedback) err
return nil
}

func (db *DB) Clean(ctx context.Context) error {
func (db *DB) Clean(ctx context.Context) (err error) {
r := db.WithContext(ctx).Exec("DELETE FROM enc_history_entries WHERE read_count > 10")
if r.Error != nil {
return r.Error
Expand All @@ -355,9 +351,8 @@ func (db *DB) Clean(ctx context.Context) error {
return nil
}

func extractInt64FromRow(row *sql.Row) (int64, error) {
var ret int64
err := row.Scan(&ret)
func extractInt64FromRow(row *sql.Row) (ret int64, err error) {
err = row.Scan(&ret)
if err != nil {
return 0, fmt.Errorf("extractInt64FromRow: %w", err)
}
Expand All @@ -376,7 +371,7 @@ type ActiveUserStats struct {
DailyUninstalls int64
}

func (db *DB) GenerateAndStoreActiveUserStats(ctx context.Context) error {
func (db *DB) GenerateAndStoreActiveUserStats(ctx context.Context) (err error) {
if db.DB.Name() == "sqlite" {
// Not supported on sqlite
return nil
Expand Down Expand Up @@ -428,7 +423,7 @@ func (db *DB) GenerateAndStoreActiveUserStats(ctx context.Context) error {
}).Error
}

func (db *DB) SelfHostedDeepClean(ctx context.Context) error {
func (db *DB) SelfHostedDeepClean(ctx context.Context) (err error) {
if db.Name() == "sqlite" {
// sqlite doesn't support the `(now() - INTERVAL '90 days')` syntax used in the below queries.
return nil
Expand Down Expand Up @@ -482,8 +477,8 @@ func (db *DB) SelfHostedDeepClean(ctx context.Context) error {
})
}

func (db *DB) DeepClean(ctx context.Context) error {
err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
func (db *DB) DeepClean(ctx context.Context) (err error) {
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Delete entries for users that have one device and are inactive
r := tx.Exec(`
CREATE TEMP TABLE temp_users_with_one_device AS (
Expand Down Expand Up @@ -548,7 +543,7 @@ func (db *DB) DeepClean(ctx context.Context) error {
if err != nil {
return err
}
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) (err error) {
r := tx.Exec(`
CREATE TEMP TABLE high_del_req_users AS (
SELECT user_id
Expand All @@ -572,7 +567,7 @@ func (db *DB) DeepClean(ctx context.Context) error {
if err != nil {
return err
}
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) (err error) {
// Delete entries for integration test users
r := tx.Exec(`
CREATE TEMP TABLE temp_inactive_devices AS (
Expand Down
15 changes: 6 additions & 9 deletions backend/server/internal/database/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ type Device struct {
UninstallDate time.Time `json:"uninstall_date"`
}

func (db *DB) CountAllDevices(ctx context.Context) (int64, error) {
var numDevices int64 = 0
func (db *DB) CountAllDevices(ctx context.Context) (numDevices int64, err error) {
tx := db.WithContext(ctx).Model(&Device{}).Count(&numDevices)
if tx.Error != nil {
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
Expand All @@ -31,8 +30,7 @@ func (db *DB) CountAllDevices(ctx context.Context) (int64, error) {
return numDevices, nil
}

func (db *DB) CountDevicesForUser(ctx context.Context, userID string) (int64, error) {
var existingDevicesCount int64
func (db *DB) CountDevicesForUser(ctx context.Context, userID string) (existingDevicesCount int64, err error) {
tx := db.WithContext(ctx).Model(&Device{}).Where("user_id = ?", userID).Count(&existingDevicesCount)
if tx.Error != nil {
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
Expand All @@ -41,7 +39,7 @@ func (db *DB) CountDevicesForUser(ctx context.Context, userID string) (int64, er
return existingDevicesCount, nil
}

func (db *DB) CreateDevice(ctx context.Context, device *Device) error {
func (db *DB) CreateDevice(ctx context.Context, device *Device) (err error) {
tx := db.WithContext(ctx).Create(device)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
Expand All @@ -50,12 +48,11 @@ func (db *DB) CreateDevice(ctx context.Context, device *Device) error {
return nil
}

func (db *DB) DevicesForUser(ctx context.Context, userID string) ([]*Device, error) {
var devices []*Device
tx := db.WithContext(ctx).Where("user_id = ? AND (uninstall_date IS NULL OR uninstall_date < '1971-01-01')", userID).Find(&devices)
func (db *DB) DevicesForUser(ctx context.Context, userID string) (devicesPtr []*Device, err error) {
tx := db.WithContext(ctx).Where("user_id = ? AND (uninstall_date IS NULL OR uninstall_date < '1971-01-01')", userID).Find(&devicesPtr)
if tx.Error != nil {
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
}

return devices, nil
return devicesPtr, nil
}
20 changes: 9 additions & 11 deletions backend/server/internal/database/historyentries.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@ import (
"gorm.io/gorm"
)

func (db *DB) CountApproximateHistoryEntries(ctx context.Context) (int64, error) {
var numDbEntries int64
err := db.WithContext(ctx).Raw("SELECT reltuples::bigint FROM pg_class WHERE relname = 'enc_history_entries'").Row().Scan(&numDbEntries)
func (db *DB) CountApproximateHistoryEntries(ctx context.Context) (numDbEntries int64, err error) {
err = db.WithContext(ctx).Raw("SELECT reltuples::bigint FROM pg_class WHERE relname = 'enc_history_entries'").Row().Scan(&numDbEntries)
if err != nil {
return 0, fmt.Errorf("DB Error: %w", err)
}

return numDbEntries, nil
}

func (db *DB) AllHistoryEntriesForUser(ctx context.Context, userID string) ([]*shared.EncHistoryEntry, error) {
func (db *DB) AllHistoryEntriesForUser(ctx context.Context, userID string) (dedupedEntries []*shared.EncHistoryEntry, err error) {
var historyEntries []*shared.EncHistoryEntry
tx := db.WithContext(ctx).Where("user_id = ?", userID).Find(&historyEntries)

Expand All @@ -34,16 +33,15 @@ func (db *DB) AllHistoryEntriesForUser(ctx context.Context, userID string) ([]*s
}

// Convert the map back to a slice
dedupedEntries := make([]*shared.EncHistoryEntry, 0, len(uniqueEntries))
dedupedEntries = make([]*shared.EncHistoryEntry, 0, len(uniqueEntries))
for _, entry := range uniqueEntries {
dedupedEntries = append(dedupedEntries, entry)
}

return dedupedEntries, nil
}

func (db *DB) HistoryEntriesForDevice(ctx context.Context, deviceID string, limit int) ([]*shared.EncHistoryEntry, error) {
var historyEntries []*shared.EncHistoryEntry
func (db *DB) HistoryEntriesForDevice(ctx context.Context, deviceID string, limit int) (historyEntries []*shared.EncHistoryEntry, err error) {
tx := db.WithContext(ctx).Where("device_id = ? AND read_count < ? AND NOT is_from_same_device", deviceID, limit).Find(&historyEntries)

if tx.Error != nil {
Expand All @@ -53,7 +51,7 @@ func (db *DB) HistoryEntriesForDevice(ctx context.Context, deviceID string, limi
return historyEntries, nil
}

func (db *DB) AddHistoryEntries(ctx context.Context, entries ...*shared.EncHistoryEntry) error {
func (db *DB) AddHistoryEntries(ctx context.Context, entries ...*shared.EncHistoryEntry) (err error) {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, entry := range entries {
resp := tx.Create(&entry)
Expand All @@ -65,7 +63,7 @@ func (db *DB) AddHistoryEntries(ctx context.Context, entries ...*shared.EncHisto
})
}

func (db *DB) AddHistoryEntriesForAllDevices(ctx context.Context, sourceDeviceId string, devices []*Device, entries []*shared.EncHistoryEntry) error {
func (db *DB) AddHistoryEntriesForAllDevices(ctx context.Context, sourceDeviceId string, devices []*Device, entries []*shared.EncHistoryEntry) (err error) {
chunkSize := 1000
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, device := range devices {
Expand All @@ -85,7 +83,7 @@ func (db *DB) AddHistoryEntriesForAllDevices(ctx context.Context, sourceDeviceId
})
}

func (db *DB) Unsafe_DeleteAllHistoryEntries(ctx context.Context) error {
func (db *DB) Unsafe_DeleteAllHistoryEntries(ctx context.Context) (err error) {
tx := db.WithContext(ctx).Exec("DELETE FROM enc_history_entries")
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
Expand All @@ -94,6 +92,6 @@ func (db *DB) Unsafe_DeleteAllHistoryEntries(ctx context.Context) error {
return nil
}

func (db *DB) IncrementEntryReadCountsForDevice(ctx context.Context, deviceID string) error {
func (db *DB) IncrementEntryReadCountsForDevice(ctx context.Context, deviceID string) (err error) {
return db.WithContext(ctx).Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceID).Error
}
Loading