Skip to content

Commit 0a320ea

Browse files
committed
feat: add provider auto-detection
Point `nix-auth login <hostname>` and it should figure out what provider it is.
1 parent 4fd57b5 commit 0a320ea

File tree

8 files changed

+318
-55
lines changed

8 files changed

+318
-55
lines changed

README.md

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ to get those tokens in the right place.
88

99
## Features
1010

11-
- OAuth device flow authentication (no manual token creation needed)
11+
- OAuth device flow authentication when possible (no manual token creation needed)
1212
- Support for multiple providers (GitHub, GitHub Enterprise, GitLab, Gitea, and Forgejo)
1313
- Token storage in `~/.config/nix/nix.conf`
1414
- Token validation and status checking
@@ -61,34 +61,25 @@ go build .
6161

6262
### Login
6363

64-
Authenticate with GitHub (default provider):
65-
66-
```bash
67-
nix-auth login
68-
```
69-
70-
Authenticate with GitLab:
64+
Authenticate with a provider:
7165

7266
```bash
67+
# Using provider aliases
68+
nix-auth login # defaults to github
69+
nix-auth login github
7370
nix-auth login gitlab
74-
```
75-
76-
Authenticate with GitHub Enterprise or GitLab self-hosted:
77-
78-
```bash
79-
# GitHub Enterprise
80-
nix-auth login github --host github.company.com --client-id <your-client-id>
81-
82-
# GitLab self-hosted
83-
nix-auth login gitlab --host gitlab.company.com --client-id <your-application-id>
84-
85-
# Gitea (uses Personal Access Token flow)
8671
nix-auth login gitea
87-
nix-auth login gitea --host gitea.company.com
72+
nix-auth login codeberg
73+
74+
# Using hosts with auto-detection
75+
nix-auth login github.com
76+
nix-auth login gitlab.company.com # auto-detects provider type
77+
nix-auth login gitea.company.com # auto-detects provider type
8878

89-
# Forgejo (uses Personal Access Token flow)
90-
nix-auth login codeberg # for codeberg.org
91-
nix-auth login forgejo --host git.company.com # for self-hosted Forgejo (--host required)
79+
# Explicit provider specification
80+
nix-auth login git.company.com --provider forgejo
81+
nix-auth login github.company.com --provider github --client-id <your-client-id>
82+
nix-auth login gitlab.company.com --provider gitlab --client-id <your-application-id>
9283
```
9384

9485
The tool will:

cmd/login.go

Lines changed: 126 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"os"
77
"strings"
8+
"time"
89

910
"github.com/numtide/nix-auth/internal/config"
1011
"github.com/numtide/nix-auth/internal/provider"
@@ -13,53 +14,123 @@ import (
1314
)
1415

1516
var loginCmd = &cobra.Command{
16-
Use: "login [provider]",
17+
Use: "login [provider-or-host]",
1718
Short: "Authenticate with a provider and save the access token",
18-
Long: `Authenticate with a provider (GitHub, GitLab, Gitea, Forgejo, etc.) using OAuth device flow
19-
and save the access token to your nix.conf for use with Nix flakes.`,
20-
Example: ` nix-auth login # defaults to GitHub
19+
Long: `Authenticate with a provider using OAuth device flow (or Personal Access Token for Gitea/Forgejo)
20+
and save the access token to your nix.conf for use with Nix flakes.
21+
22+
You can specify either:
23+
- A provider alias (github, gitlab, gitea, codeberg) - uses default host for that provider
24+
- A host (e.g., github.com, git.company.com) - auto-detects provider type by querying API
25+
26+
Notes:
27+
- The --provider flag only works when specifying a host, not with provider aliases
28+
- For Forgejo, you must specify a host as it has no default: nix-auth login <host> --provider forgejo
29+
- Using both a provider alias and --provider flag will result in an error`,
30+
Example: ` # Using provider aliases
31+
nix-auth login # defaults to github
2132
nix-auth login github
2233
nix-auth login gitlab
2334
nix-auth login gitea
24-
nix-auth login codeberg # for codeberg.org
25-
nix-auth login forgejo --host git.company.com # --host required for forgejo
26-
nix-auth login github --host github.company.com --client-id abc123
27-
nix-auth login gitea --host gitea.company.com`,
28-
Args: cobra.MaximumNArgs(1),
29-
RunE: runLogin,
35+
nix-auth login codeberg
36+
37+
# Using hosts with auto-detection
38+
nix-auth login github.com
39+
nix-auth login gitlab.company.com # auto-detects provider type
40+
nix-auth login git.company.com # auto-detects provider type
41+
42+
# Explicit provider specification
43+
nix-auth login git.company.com --provider forgejo
44+
nix-auth login github.company.com --client-id abc123`,
45+
Args: cobra.MaximumNArgs(1),
46+
RunE: runLogin,
3047
}
3148

3249
var (
33-
loginHost string
50+
loginProvider string
3451
loginClientID string
52+
loginForce bool
53+
loginTimeout int
54+
loginDryRun bool
3555
)
3656

3757
func init() {
38-
loginCmd.Flags().StringVar(&loginHost, "host", "", "Custom host (e.g., github.company.com)")
39-
loginCmd.Flags().StringVar(&loginClientID, "client-id", "", "OAuth client ID (required for self-hosted instances)")
58+
loginCmd.Flags().StringVar(&loginProvider, "provider", "auto", "Provider type when using a host (auto, github, gitlab, gitea, forgejo, codeberg)")
59+
loginCmd.Flags().StringVar(&loginClientID, "client-id", "", "OAuth client ID (required for GitHub Enterprise, optional for others)")
60+
loginCmd.Flags().BoolVar(&loginForce, "force", false, "Skip confirmation prompt when replacing existing tokens")
61+
loginCmd.Flags().IntVar(&loginTimeout, "timeout", 30, "Timeout in seconds for network operations")
62+
loginCmd.Flags().BoolVar(&loginDryRun, "dry-run", false, "Preview what would happen without authenticating")
4063
}
4164

4265
func runLogin(cmd *cobra.Command, args []string) error {
43-
providerName := "github" // default
66+
var host string
67+
var providerName string
68+
var prov provider.Provider
69+
70+
// Parse the input
71+
input := "github" // default
4472
if len(args) > 0 {
45-
providerName = strings.ToLower(args[0])
73+
input = strings.ToLower(args[0])
4674
}
4775

48-
// Get provider
49-
prov, ok := provider.Get(providerName)
50-
if !ok {
51-
available := strings.Join(provider.List(), ", ")
52-
return fmt.Errorf("unknown provider '%s'. Available providers: %s", providerName, available)
76+
// First, determine if we're dealing with a host or provider alias
77+
isProviderAlias := false
78+
if _, ok := provider.Get(input); ok {
79+
isProviderAlias = true
80+
// Check for conflicts
81+
if loginProvider != "auto" && loginProvider != input {
82+
return fmt.Errorf("cannot use --provider %s with provider alias '%s'\n"+
83+
"Use: nix-auth login %s", loginProvider, input, input)
84+
}
5385
}
5486

55-
// Determine host
56-
host := prov.Host()
57-
if loginHost != "" {
58-
host = loginHost
59-
}
87+
if isProviderAlias {
88+
// Handle provider alias
89+
prov, _ = provider.Get(input)
90+
providerName = input
91+
host = prov.Host()
6092

61-
// Always set the host (even if it's the default)
62-
prov.SetHost(host)
93+
// For providers without a default host, require explicit host
94+
if host == "" {
95+
return fmt.Errorf("provider '%s' requires a host\n"+
96+
"Use: nix-auth login <host> --provider %s", input, input)
97+
}
98+
} else {
99+
// It's a host
100+
host = input
101+
102+
// Determine the provider
103+
if loginProvider == "auto" {
104+
// Auto-detect provider type
105+
fmt.Printf("Detecting provider type for %s by querying API...\n", host)
106+
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(loginTimeout)*time.Second)
107+
defer cancel()
108+
109+
detectedProvider, err := provider.DetectProviderFromHost(ctx, host)
110+
if err != nil {
111+
return fmt.Errorf("failed to detect provider for %s: %w\n"+
112+
"Try: nix-auth login %s --provider <github|gitlab|gitea|forgejo>",
113+
host, err, host)
114+
}
115+
116+
providerName = detectedProvider
117+
fmt.Printf("Detected: %s\n\n", providerName)
118+
} else {
119+
// Use explicitly specified provider
120+
providerName = loginProvider
121+
}
122+
123+
// Get the provider instance
124+
var ok bool
125+
prov, ok = provider.Get(providerName)
126+
if !ok {
127+
available := strings.Join(provider.List(), ", ")
128+
return fmt.Errorf("unknown provider '%s'. Available providers: %s", providerName, available)
129+
}
130+
131+
// Set the host on the provider
132+
prov.SetHost(host)
133+
}
63134

64135
// Set client ID: use flag, fallback to environment variable
65136
clientID := loginClientID
@@ -78,14 +149,28 @@ func runLogin(cmd *cobra.Command, args []string) error {
78149

79150
fmt.Printf("Authenticating with %s (%s)...\n", prov.Name(), host)
80151

152+
// If dry-run, show what would happen and exit
153+
if loginDryRun {
154+
fmt.Println("\nDry-run mode: Preview of what would happen:")
155+
fmt.Printf("- Provider: %s\n", prov.Name())
156+
fmt.Printf("- Host: %s\n", host)
157+
fmt.Printf("- OAuth scopes: %s\n", strings.Join(prov.GetScopes(), ", "))
158+
if clientID != "" {
159+
fmt.Printf("- Client ID: %s\n", clientID)
160+
}
161+
fmt.Printf("- Config file: %s\n", configPath)
162+
fmt.Println("\nNo authentication performed. Run without --dry-run to authenticate.")
163+
return nil
164+
}
165+
81166
// Check if token already exists
82167
cfg, err := config.New(configPath)
83168
if err != nil {
84169
return fmt.Errorf("failed to initialize config: %w", err)
85170
}
86171

87172
existingToken, _ := cfg.GetToken(host)
88-
if existingToken != "" {
173+
if existingToken != "" && !loginForce {
89174
fmt.Printf("A token for %s already exists. Do you want to replace it? [y/N] ", host)
90175
var response string
91176
fmt.Scanln(&response)
@@ -96,10 +181,21 @@ func runLogin(cmd *cobra.Command, args []string) error {
96181
}
97182

98183
// Perform authentication
99-
ctx := context.Background()
184+
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(loginTimeout)*time.Second)
185+
defer cancel()
100186
token, err := prov.Authenticate(ctx)
101187
if err != nil {
102-
return fmt.Errorf("authentication failed: %w", err)
188+
errMsg := fmt.Sprintf("authentication failed: %v", err)
189+
if strings.Contains(err.Error(), "context deadline exceeded") {
190+
errMsg += fmt.Sprintf("\n\nThe operation timed out after %d seconds. Try:\n"+
191+
"- Increasing the timeout: --timeout 60\n"+
192+
"- Checking your internet connection\n"+
193+
"- Verifying the host is accessible: curl https://%s", loginTimeout, host)
194+
} else if strings.Contains(err.Error(), "client ID") {
195+
errMsg += "\n\nFor self-hosted instances, you need to create an OAuth application.\n" +
196+
"See the instructions above or use --dry-run to preview the configuration."
197+
}
198+
return fmt.Errorf(errMsg)
103199
}
104200

105201
// Validate token

internal/provider/detection.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package provider
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
"time"
8+
)
9+
10+
// DetectProviderFromHost attempts to identify the provider type by querying various API endpoints
11+
func DetectProviderFromHost(ctx context.Context, host string) (string, error) {
12+
// Create a client with timeout
13+
client := &http.Client{
14+
Timeout: 10 * time.Second,
15+
}
16+
17+
// Try each registered provider
18+
for name, provider := range Registry {
19+
// Create a new instance to avoid mutating the registered provider
20+
p := provider
21+
if p.DetectHost(ctx, client, host) {
22+
return name, nil
23+
}
24+
}
25+
26+
return "", fmt.Errorf("unable to detect provider type for host: %s", host)
27+
}

internal/provider/forgejo.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,41 @@ func (f *ForgejoProvider) SetHost(host string) {
2626
func (f *ForgejoProvider) SetClientID(clientID string) {
2727
}
2828

29+
// DetectHost checks if the given host is a Forgejo instance
30+
func (f *ForgejoProvider) DetectHost(ctx context.Context, client *http.Client, host string) bool {
31+
// Known Forgejo/Codeberg host
32+
if strings.ToLower(host) == "codeberg.org" {
33+
return true
34+
}
35+
36+
// For other hosts, check if it's a Forgejo instance using the version endpoint
37+
baseURL := fmt.Sprintf("https://%s", host)
38+
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/api/v1/version", baseURL), nil)
39+
if err != nil {
40+
return false
41+
}
42+
43+
resp, err := client.Do(req)
44+
if err != nil {
45+
return false
46+
}
47+
defer resp.Body.Close()
48+
49+
if resp.StatusCode == http.StatusOK {
50+
var data struct {
51+
Version string `json:"version"`
52+
}
53+
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
54+
return false
55+
}
56+
// Forgejo includes "forgejo" in the version string
57+
if strings.Contains(strings.ToLower(data.Version), "forgejo") {
58+
return true
59+
}
60+
}
61+
return false
62+
}
63+
2964
func (f *ForgejoProvider) getBaseURL() string {
3065
if f.host != "" {
3166
return fmt.Sprintf("https://%s", f.host)

internal/provider/gitea.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,42 @@ func (g *GiteaProvider) SetHost(host string) {
2525
func (g *GiteaProvider) SetClientID(clientID string) {
2626
}
2727

28+
// DetectHost checks if the given host is a Gitea instance
29+
func (g *GiteaProvider) DetectHost(ctx context.Context, client *http.Client, host string) bool {
30+
// Known Gitea hosts
31+
lowerHost := strings.ToLower(host)
32+
if lowerHost == "gitea.com" || lowerHost == "gitea.io" {
33+
return true
34+
}
35+
36+
// For other hosts, check if it's a Gitea instance using the version endpoint
37+
baseURL := fmt.Sprintf("https://%s", host)
38+
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/api/v1/version", baseURL), nil)
39+
if err != nil {
40+
return false
41+
}
42+
43+
resp, err := client.Do(req)
44+
if err != nil {
45+
return false
46+
}
47+
defer resp.Body.Close()
48+
49+
if resp.StatusCode == http.StatusOK {
50+
var data struct {
51+
Version string `json:"version"`
52+
}
53+
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
54+
return false
55+
}
56+
// Check if it's NOT Forgejo (Forgejo includes "forgejo" in version string)
57+
if data.Version != "" && !strings.Contains(strings.ToLower(data.Version), "forgejo") {
58+
return true
59+
}
60+
}
61+
return false
62+
}
63+
2864
func (g *GiteaProvider) getBaseURL() string {
2965
if g.host != "" {
3066
return fmt.Sprintf("https://%s", g.host)

0 commit comments

Comments
 (0)