Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ linters:
- $test
allow:
- $gostd
- github.com/golang/mock/gomock
- github.com/openfga/api/proto
- github.com/openfga/cli
- github.com/openfga/go-sdk
- github.com/openfga/openfga
- github.com/spf13/cobra
- github.com/stretchr
- go.uber.org/mock/gomock
funlen:
Expand Down
5 changes: 3 additions & 2 deletions cmd/model/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/openfga/cli/internal/authorizationmodel"
"github.com/openfga/cli/internal/cmdutils"
"github.com/openfga/cli/internal/flags"
"github.com/openfga/cli/internal/output"
)

Expand Down Expand Up @@ -77,8 +78,8 @@ func init() {
getCmd.Flags().StringArray("field", []string{"model"}, "Fields to display, choices are: id, created_at and model") //nolint:lll
getCmd.Flags().Var(&getOutputFormat, "format", `Authorization model output format. Can be "fga" or "json"`)

if err := getCmd.MarkFlagRequired("store-id"); err != nil {
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/get", err)
if err := flags.SetFlagRequired(getCmd, "store-id", "cmd/models/get", false); err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
}
}
5 changes: 3 additions & 2 deletions cmd/model/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/openfga/cli/internal/authorizationmodel"
"github.com/openfga/cli/internal/cmdutils"
"github.com/openfga/cli/internal/flags"
"github.com/openfga/cli/internal/output"
)

Expand Down Expand Up @@ -112,8 +113,8 @@ func init() {
listCmd.Flags().String("store-id", "", "Store ID")
listCmd.Flags().StringArray("field", []string{"id", "created_at"}, "Fields to display, choices are: id, created_at and model") //nolint:lll

if err := listCmd.MarkFlagRequired("store-id"); err != nil {
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/list", err)
if err := flags.SetFlagRequired(listCmd, "store-id", "cmd/models/list", false); err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
}
}
5 changes: 3 additions & 2 deletions cmd/model/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/spf13/cobra"

"github.com/openfga/cli/internal/cmdutils"
"github.com/openfga/cli/internal/flags"
"github.com/openfga/cli/internal/output"
"github.com/openfga/cli/internal/storetest"
)
Expand Down Expand Up @@ -99,8 +100,8 @@ func init() {
testCmd.Flags().Bool("verbose", false, "Print verbose JSON output")
testCmd.Flags().Bool("suppress-summary", false, "Suppress the plain text summary output")

if err := testCmd.MarkFlagRequired("tests"); err != nil {
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/test", err)
if err := flags.SetFlagRequired(testCmd, "tests", "cmd/models/test", false); err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
}
}
5 changes: 3 additions & 2 deletions cmd/model/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/openfga/cli/internal/authorizationmodel"
"github.com/openfga/cli/internal/cmdutils"
"github.com/openfga/cli/internal/flags"
"github.com/openfga/cli/internal/output"
"github.com/openfga/cli/internal/utils"
)
Expand Down Expand Up @@ -106,8 +107,8 @@ func init() {
writeCmd.Flags().String("file", "", "File Name. The file should have the model in the JSON or DSL format")
writeCmd.Flags().Var(&writeInputFormat, "format", `Authorization model input format. Can be "fga", "json", or "modular"`) //nolint:lll

if err := writeCmd.MarkFlagRequired("store-id"); err != nil {
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/write", err)
if err := flags.SetFlagRequired(writeCmd, "store-id", "cmd/model/write", false); err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
}
}
18 changes: 6 additions & 12 deletions cmd/query/list-users.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/spf13/cobra"

"github.com/openfga/cli/internal/cmdutils"
"github.com/openfga/cli/internal/flags"
"github.com/openfga/cli/internal/output"
)

Expand Down Expand Up @@ -138,18 +139,11 @@ func init() {
listUsersCmd.Flags().String("relation", "", "Relation to evaluate on")
listUsersCmd.Flags().String("user-filter", "", "Filter the responses can be in the formats <type> (to filter objects and typed public bound access) or <type>#<relation> (to filter usersets)") //nolint:lll

if err := listUsersCmd.MarkFlagRequired("object"); err != nil {
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/query/list-users", err)
os.Exit(1)
}

if err := listUsersCmd.MarkFlagRequired("relation"); err != nil {
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/query/list-users", err)
os.Exit(1)
}

if err := listUsersCmd.MarkFlagRequired("user-filter"); err != nil {
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/query/list-users", err)
if err := flags.SetFlagsRequired(
listUsersCmd,
[]string{"object", "relation", "user-filter"},
"cmd/query/list-users", false); err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
}
}
7 changes: 4 additions & 3 deletions cmd/query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"os"

"github.com/spf13/cobra"

"github.com/openfga/cli/internal/flags"
)

// QueryCmd represents the query command.
Expand All @@ -48,9 +50,8 @@ func init() {
"Consistency preference for the request. Valid options are HIGHER_CONSISTENCY and MINIMIZE_LATENCY.",
)

err := QueryCmd.MarkPersistentFlagRequired("store-id")
if err != nil {
fmt.Print(err)
if err := flags.SetFlagRequired(QueryCmd, "store-id", "cmd/query/query", true); err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
}
}
6 changes: 3 additions & 3 deletions cmd/store/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/openfga/cli/internal/cmdutils"
"github.com/openfga/cli/internal/confirmation"
"github.com/openfga/cli/internal/flags"
"github.com/openfga/cli/internal/output"
)

Expand Down Expand Up @@ -72,9 +73,8 @@ func init() {
deleteCmd.Flags().String("store-id", "", "Store ID")
deleteCmd.Flags().Bool("force", false, "Force delete without confirmation")

err := deleteCmd.MarkFlagRequired("store-id")
if err != nil {
fmt.Print(err)
if err := flags.SetFlagRequired(deleteCmd, "store-id", "cmd/store/delete", false); err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
}
}
6 changes: 3 additions & 3 deletions cmd/store/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/openfga/cli/internal/cmdutils"
"github.com/openfga/cli/internal/confirmation"
"github.com/openfga/cli/internal/fga"
"github.com/openfga/cli/internal/flags"
"github.com/openfga/cli/internal/output"
"github.com/openfga/cli/internal/storetest"
"github.com/openfga/cli/internal/tuple"
Expand Down Expand Up @@ -193,9 +194,8 @@ func init() {
exportCmd.Flags().String("model-id", "", "Authorization Model ID")
exportCmd.Flags().Uint("max-tuples", defaultMaxTupleCount, "max number of tuples to return in the output")

err := exportCmd.MarkFlagRequired("store-id")
if err != nil {
fmt.Print(err)
if err := flags.SetFlagRequired(exportCmd, "store-id", "cmd/store/export", false); err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
}
}
6 changes: 3 additions & 3 deletions cmd/store/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/openfga/cli/internal/cmdutils"
"github.com/openfga/cli/internal/fga"
"github.com/openfga/cli/internal/flags"
"github.com/openfga/cli/internal/output"
)

Expand Down Expand Up @@ -65,9 +66,8 @@ var getCmd = &cobra.Command{
func init() {
getCmd.Flags().String("store-id", "", "Store ID")

err := getCmd.MarkFlagRequired("store-id")
if err != nil {
fmt.Print(err)
if err := flags.SetFlagRequired(getCmd, "store-id", "cmd/store/get", false); err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
}
}
5 changes: 3 additions & 2 deletions cmd/store/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/openfga/cli/internal/authorizationmodel"
"github.com/openfga/cli/internal/cmdutils"
"github.com/openfga/cli/internal/fga"
"github.com/openfga/cli/internal/flags"
"github.com/openfga/cli/internal/output"
"github.com/openfga/cli/internal/storetest"
"github.com/openfga/cli/internal/tuple"
Expand Down Expand Up @@ -339,8 +340,8 @@ func init() {
importCmd.Flags().Int("max-tuples-per-write", tuple.MaxTuplesPerWrite, "Max tuples per write chunk.")
importCmd.Flags().Int("max-parallel-requests", tuple.MaxParallelRequests, "Max number of requests to issue to the server in parallel.") //nolint:lll

if err := importCmd.MarkFlagRequired("file"); err != nil {
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/write", err)
if err := flags.SetFlagRequired(importCmd, "file", "cmd/store/import", false); err != nil {
fmt.Printf("%v\n", err)
os.Exit(1)
}
}
7 changes: 4 additions & 3 deletions cmd/tuple/tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"os"

"github.com/spf13/cobra"

"github.com/openfga/cli/internal/flags"
)

// TupleCmd represents the tuple command.
Expand All @@ -39,9 +41,8 @@ func init() {

TupleCmd.PersistentFlags().String("store-id", "", "Store ID")

err := TupleCmd.MarkPersistentFlagRequired("store-id")
if err != nil { //nolint:wsl
fmt.Print(err)
if err := flags.SetFlagRequired(TupleCmd, "store-id", "cmd/tuple/tuple", true); err != nil {
fmt.Printf("%v\n", err)
Copy link

Copilot AI Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Error messages should be written to stderr rather than stdout. Consider using fmt.Fprintln(os.Stderr, err) to ensure errors go to the correct output stream.

Suggested change
fmt.Printf("%v\n", err)
fmt.Fprintln(os.Stderr, err)

Copilot uses AI. Check for mistakes.
os.Exit(1)
}
}
101 changes: 101 additions & 0 deletions internal/flags/flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Package flags provides utility functions for working with cobra command flags.
// It simplifies the process of marking flags as required and handling related errors.
package flags

import (
"errors"
"fmt"
"strings"

"github.com/spf13/cobra"
)

var (
// ErrFlagRequired is returned when a flag cannot be marked as required.
ErrFlagRequired = errors.New("error setting flag as required")

// ErrInvalidInput is returned when invalid input is provided.
ErrInvalidInput = errors.New("invalid input")
)

// buildFlagRequiredError creates a consistent error message for flag requirement failures.
// It wraps the original error with context about which flag and location failed.
func buildFlagRequiredError(flag, location string, err error) error {
if err == nil {
return nil
}

return fmt.Errorf("%w - (flag: %s, file: %s): %w", ErrFlagRequired, flag, location, err)
}

// SetFlagRequired marks a single flag as required for a cobra command.
//
// Parameters:
// - cmd: The cobra command to modify
// - flag: The name of the flag to mark as required
// - location: A string identifying the calling location (for error context)
// - isPersistent: If true, marks the persistent flag as required; otherwise marks the regular flag
//
// Returns an error if:
// - cmd is nil
// - flag is empty
// - the flag cannot be marked as required (e.g., flag doesn't exist)
func SetFlagRequired(cmd *cobra.Command, flag string, location string, isPersistent bool) error {
if cmd == nil {
return fmt.Errorf("%w: command cannot be nil", ErrInvalidInput)
}

if strings.TrimSpace(flag) == "" {
return fmt.Errorf("%w: flag name cannot be empty", ErrInvalidInput)
}

if isPersistent {
if err := cmd.MarkPersistentFlagRequired(flag); err != nil {
return buildFlagRequiredError(flag, location, err)
}
} else {
if err := cmd.MarkFlagRequired(flag); err != nil {
return buildFlagRequiredError(flag, location, err)
}
}

return nil
}

// SetFlagsRequired marks multiple flags as required for a cobra command.
//
// Parameters:
// - cmd: The cobra command to modify
// - flags: A slice of flag names to mark as required
// - location: A string identifying the calling location (for error context)
// - isPersistent: If true, marks the persistent flags as required; otherwise marks the regular flags
//
// Returns a joined error containing all individual flag requirement failures.
// If no flags are provided or all succeed, returns nil.
//
// Note: This function continues processing all flags even if some fail,
// allowing you to see all failures at once rather than stopping at the first error.
func SetFlagsRequired(cmd *cobra.Command, flags []string, location string, isPersistent bool) error {
if cmd == nil {
return fmt.Errorf("%w: command cannot be nil", ErrInvalidInput)
}

if len(flags) == 0 {
return nil
}

// Pre-allocate slice with exact capacity needed
flagErrors := make([]error, 0, len(flags))

for _, flag := range flags {
if err := SetFlagRequired(cmd, flag, location, isPersistent); err != nil {
flagErrors = append(flagErrors, err)
}
}

if len(flagErrors) > 0 {
return errors.Join(flagErrors...)
}

return nil
}
Loading
Loading