diff --git a/db.go b/db.go index d78062e5..6ca9b2d2 100644 --- a/db.go +++ b/db.go @@ -29,6 +29,8 @@ import ( // type DbMap struct { ctx context.Context + // TimeOutInterval is used to set the timeout interval for transactions based on DbMap + TimeOutInterval time.Duration // Db handle to use with this map Db *sql.DB diff --git a/gorp.go b/gorp.go index fc654567..3ee417b0 100644 --- a/gorp.go +++ b/gorp.go @@ -199,6 +199,34 @@ func extractExecutorAndContext(e SqlExecutor) (executor, context.Context) { return nil, nil } +func extractTimeOutInterval(e SqlExecutor) time.Duration { + switch m := e.(type) { + case *DbMap: + return m.TimeOutInterval + case *Transaction: + return m.TimeOutInterval + } + return 0 +} + +func createNewContext(ctx context.Context, duration time.Duration) (context.Context, context.CancelFunc) { + if ctx == nil { + if duration == 0 { + return nil, nil + } else { + return context.WithTimeout(context.Background(), duration) + } + + } else { + if duration == 0 { + return ctx, nil + } else { + return context.WithTimeout(ctx, duration) + } + } + +} + // maybeExpandNamedQuery checks the given arg to see if it's eligible to be used // as input to a named query. If so, it rewrites the query to use // dialect-dependent bindvars and instantiates the corresponding slice of @@ -625,6 +653,11 @@ func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { func exec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) { executor, ctx := extractExecutorAndContext(e) + timeout := extractTimeOutInterval(e) + ctx, cancel := createNewContext(ctx, timeout) + if cancel != nil { + defer cancel() + } if ctx != nil { return executor.ExecContext(ctx, query, args...) @@ -635,6 +668,11 @@ func exec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) func prepare(e SqlExecutor, query string) (*sql.Stmt, error) { executor, ctx := extractExecutorAndContext(e) + timeout := extractTimeOutInterval(e) + ctx, cancel := createNewContext(ctx, timeout) + if cancel != nil { + defer cancel() + } if ctx != nil { return executor.PrepareContext(ctx, query) @@ -645,6 +683,11 @@ func prepare(e SqlExecutor, query string) (*sql.Stmt, error) { func queryRow(e SqlExecutor, query string, args ...interface{}) *sql.Row { executor, ctx := extractExecutorAndContext(e) + timeout := extractTimeOutInterval(e) + ctx, cancel := createNewContext(ctx, timeout) + if cancel != nil { + defer cancel() + } if ctx != nil { return executor.QueryRowContext(ctx, query, args...) @@ -655,6 +698,11 @@ func queryRow(e SqlExecutor, query string, args ...interface{}) *sql.Row { func query(e SqlExecutor, query string, args ...interface{}) (*sql.Rows, error) { executor, ctx := extractExecutorAndContext(e) + timeout := extractTimeOutInterval(e) + ctx, cancel := createNewContext(ctx, timeout) + if cancel != nil { + defer cancel() + } if ctx != nil { return executor.QueryContext(ctx, query, args...) diff --git a/transaction.go b/transaction.go index d505d94c..b6cca4ab 100644 --- a/transaction.go +++ b/transaction.go @@ -15,10 +15,11 @@ import ( // of that transaction. Transactions should be terminated with // a call to Commit() or Rollback() type Transaction struct { - ctx context.Context - dbmap *DbMap - tx *sql.Tx - closed bool + ctx context.Context + TimeOutInterval time.Duration + dbmap *DbMap + tx *sql.Tx + closed bool } func (t *Transaction) WithContext(ctx context.Context) SqlExecutor {