Skip to content

Commit 5886102

Browse files
Primary branch flag for scan create command (AST-102468) (#1207)
Set branch as primary using branch primary flag
1 parent eb4c35d commit 5886102

File tree

6 files changed

+84
-13
lines changed

6 files changed

+84
-13
lines changed

internal/commands/scan.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ const (
118118
ScsRepoWarningMsg = "SCS scan warning: Unable to start Scorecard scan due to missing required flags, please include in the ast-cli arguments: " +
119119
"--scs-repo-url your_repo_url --scs-repo-token your_repo_token"
120120
ScsScorecardUnsupportedHostWarningMsg = "SCS scan warning: Unable to run Scorecard scanner due to unsupported repo host. Currently, Scorecard can only run on GitHub Cloud repos."
121+
BranchPrimaryPrefix = "--branch-primary="
121122
)
122123

123124
var (
@@ -722,6 +723,11 @@ func scanCreateSubCommand(
722723
"Enable SAST scan using light query configuration",
723724
)
724725

726+
createScanCmd.PersistentFlags().Bool(
727+
commonParams.BranchPrimaryFlag,
728+
false,
729+
"This flag sets the branch specified in --branch as the PRIMARY branch for the project")
730+
725731
createScanCmd.PersistentFlags().Bool(
726732
commonParams.SastRecommendedExclusionsFlags,
727733
false,
@@ -845,6 +851,7 @@ func setupScanTypeProjectAndConfig(
845851
userAllowedEngines, _ := jwtWrapper.GetAllowedEngines(featureFlagsWrapper)
846852
var info map[string]interface{}
847853
newProjectName, _ := cmd.Flags().GetString(commonParams.ProjectName)
854+
848855
_ = json.Unmarshal(*input, &info)
849856
info[resultsMapType] = getUploadType(cmd)
850857
// Handle the project settings
@@ -3006,6 +3013,15 @@ func validateCreateScanFlags(cmd *cobra.Command) error {
30063013
return fmt.Errorf("Invalid value for --%s flag. Must be a valid UUID.", commonParams.IacsPresetIDFlag)
30073014
}
30083015
}
3016+
// check if flag was passed as arg
3017+
isBranchChanged := cmd.Flags().Changed(commonParams.BranchPrimaryFlag)
3018+
if isBranchChanged {
3019+
for _, a := range os.Args[1:] {
3020+
if strings.HasPrefix(a, BranchPrimaryPrefix) {
3021+
return fmt.Errorf("invalid value for --branch-primary flag. This flag is sent without any values")
3022+
}
3023+
}
3024+
}
30093025

30103026
return nil
30113027
}

internal/commands/scan_test.go

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ const (
6363
SCSScoreCardError = "SCS scan failed to start: Scorecard scan is missing required flags, please include in the ast-cli arguments: " +
6464
"--scs-repo-url your_repo_url --scs-repo-token your_repo_token"
6565
outputFileName = "test_output.log"
66-
noUpdatesForExistingProject = "No tags to update. Skipping project update."
66+
noUpdatesForExistingProject = "No tags or branch to update. Skipping project update."
6767
ScaResolverZipNotSupportedErr = "Scanning Zip files is not supported by ScaResolver.Please use non-zip source"
6868
)
6969

@@ -400,6 +400,7 @@ func TestCreateScanBranches(t *testing.T) {
400400
// Bind cx_branch environment variable
401401
_ = viper.BindEnv("cx_branch", "CX_BRANCH")
402402
viper.SetDefault("cx_branch", "branch_from_environment_variable")
403+
assert.Equal(t, viper.GetString("cx_branch"), "branch_from_environment_variable")
403404

404405
// Test branch from environment variable. Since the cx_branch is bind the scan must run successfully without a branch flag defined
405406
execCmdNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo)
@@ -637,6 +638,35 @@ func TestCreateScanResubmitWithScanTypes(t *testing.T) {
637638
execCmdNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--scan-types", "sast,iac-security,sca", "--debug", "--resubmit")
638639
}
639640

641+
func TestCreateScanWithPrimaryBranchFlag_Passed(t *testing.T) {
642+
execCmdNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary")
643+
}
644+
645+
func TestCreateScanWithPrimaryBranchFlagBooleanValueTrue_Failed(t *testing.T) {
646+
original := os.Args
647+
defer func() { os.Args = original }()
648+
os.Args = []string{
649+
"scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary=true",
650+
}
651+
err := execCmdNotNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary=true")
652+
assert.ErrorContains(t, err, "invalid value for --branch-primary flag", err.Error())
653+
}
654+
655+
func TestCreateScanWithPrimaryBranchFlagBooleanValueFalse_Failed(t *testing.T) {
656+
original := os.Args
657+
defer func() { os.Args = original }()
658+
os.Args = []string{
659+
"scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary=false",
660+
}
661+
err := execCmdNotNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary=false")
662+
assert.ErrorContains(t, err, "invalid value for --branch-primary flag", err.Error())
663+
}
664+
665+
func TestCreateScanWithPrimaryBranchFlagStringValue_Should_Fail(t *testing.T) {
666+
err := execCmdNotNilAssertion(t, "scan", "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch", "--debug", "--branch-primary=string")
667+
assert.ErrorContains(t, err, "invalid argument \"string\"", err.Error())
668+
}
669+
640670
func Test_parseThresholdSuccess(t *testing.T) {
641671
want := make(map[string]int)
642672
want["iac-security-low"] = 1
@@ -645,7 +675,6 @@ func Test_parseThresholdSuccess(t *testing.T) {
645675
t.Errorf("parseThreshold() = %v, want %v", got, want)
646676
}
647677
}
648-
649678
func Test_parseThresholdsSuccess(t *testing.T) {
650679
want := make(map[string]int)
651680
want["sast-high"] = 1
@@ -656,15 +685,13 @@ func Test_parseThresholdsSuccess(t *testing.T) {
656685
t.Errorf("parseThreshold() = %v, want %v", got, want)
657686
}
658687
}
659-
660688
func Test_parseThresholdParseError(t *testing.T) {
661689
want := make(map[string]int)
662690
threshold := " KICS - LoW=error"
663691
if got := parseThreshold(threshold); !reflect.DeepEqual(got, want) {
664692
t.Errorf("parseThreshold() = %v, want %v", got, want)
665693
}
666694
}
667-
668695
func TestCreateScanProjectTags(t *testing.T) {
669696
execCmdNilAssertion(t, scanCommand, "create", "--project-name", "MOCK", "-s", dummyRepo, "-b", "dummy_branch",
670697
"--project-tags", "test", "--debug")

internal/params/flags.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ const (
7070
NtlmProxyDomainFlag = "proxy-ntlm-domain"
7171
SastFastScanFlag = "sast-fast-scan"
7272
SastLightQueriesFlag = "sast-light-queries"
73+
BranchPrimaryFlag = "branch-primary"
7374
SastRecommendedExclusionsFlags = "sast-recommended-exclusions"
7475
NtlmProxyDomainFlagUsage = "Window domain when using NTLM proxy"
7576
BaseURIFlagUsage = "The base system URI"

internal/services/projects.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package services
22

33
import (
4+
"fmt"
45
"slices"
56
"strconv"
7+
"strings"
68
"time"
79

810
featureFlagsConstants "github.com/checkmarx/ast-cli/internal/constants/feature-flags"
@@ -11,6 +13,7 @@ import (
1113
"github.com/checkmarx/ast-cli/internal/wrappers"
1214
"github.com/pkg/errors"
1315
"github.com/spf13/cobra"
16+
"github.com/spf13/viper"
1417
)
1518

1619
const (
@@ -31,17 +34,19 @@ func FindProject(
3134
applicationWrapper wrappers.ApplicationsWrapper,
3235
featureFlagsWrapper wrappers.FeatureFlagsWrapper,
3336
) (string, error) {
37+
var isBranchPrimary bool
3438
resp, err := GetProjectsCollectionByProjectName(projectName, projectsWrapper)
3539
if err != nil {
3640
return "", err
3741
}
38-
42+
branchName := strings.TrimSpace(viper.GetString(commonParams.BranchKey))
43+
isBranchPrimary, _ = cmd.Flags().GetBool(commonParams.BranchPrimaryFlag)
3944
for i := 0; i < len(resp.Projects); i++ {
4045
project := resp.Projects[i]
4146
if project.Name == projectName {
4247
projectTags, _ := cmd.Flags().GetString(commonParams.ProjectTagList)
4348
projectPrivatePackage, _ := cmd.Flags().GetString(commonParams.ProjecPrivatePackageFlag)
44-
return updateProject(&project, projectsWrapper, projectTags, projectPrivatePackage)
49+
return updateProject(&project, projectsWrapper, projectTags, projectPrivatePackage, isBranchPrimary, branchName)
4550
}
4651
}
4752

@@ -55,7 +60,7 @@ func FindProject(
5560
}
5661

5762
projectID, err := createProject(projectName, cmd, projectsWrapper, groupsWrapper, accessManagementWrapper, applicationWrapper,
58-
applicationID, projectGroups, projectPrivatePackage, featureFlagsWrapper)
63+
applicationID, projectGroups, projectPrivatePackage, featureFlagsWrapper, isBranchPrimary, branchName)
5964
if err != nil {
6065
logger.PrintIfVerbose("error in creating project!")
6166
return "", err
@@ -97,12 +102,18 @@ func createProject(
97102
projectGroups string,
98103
projectPrivatePackage string,
99104
featureFlagsWrapper wrappers.FeatureFlagsWrapper,
105+
isBranchPrimary bool,
106+
branchName string,
100107
) (string, error) {
101108
projectTags, _ := cmd.Flags().GetString(commonParams.ProjectTagList)
102109
applicationName, _ := cmd.Flags().GetString(commonParams.ApplicationName)
103110
var projModel = wrappers.Project{}
104111
projModel.Name = projectName
105112
projModel.ApplicationIds = applicationID
113+
if isBranchPrimary {
114+
logger.PrintIfVerbose(fmt.Sprintf("Setting the branch in project : %s", branchName))
115+
projModel.MainBranch = branchName
116+
}
106117
var groupsMap []*wrappers.Group
107118
if projectGroups != "" {
108119
var groups []string
@@ -179,14 +190,20 @@ func verifyApplicationAssociationDone(applicationName, projectID string, applica
179190
//nolint:gocyclo
180191
func updateProject(project *wrappers.ProjectResponseModel,
181192
projectsWrapper wrappers.ProjectsWrapper,
182-
projectTags string, projectPrivatePackage string) (string, error) {
193+
projectTags string, projectPrivatePackage string, isBranchPrimary bool, branchName string) (string, error) {
183194
var projectID string
184195
var projModel = wrappers.Project{}
185196
projectID = project.ID
186-
projModel.MainBranch = project.MainBranch
197+
if isBranchPrimary {
198+
projModel.MainBranch = branchName
199+
logger.PrintfIfVerbose("Updating the branch as primary: %s", branchName)
200+
} else {
201+
projModel.MainBranch = project.MainBranch
202+
}
187203
projModel.RepoURL = project.RepoURL
188-
if projectTags == "" && projectPrivatePackage == "" {
189-
logger.PrintIfVerbose("No tags to update. Skipping project update.")
204+
205+
if projectTags == "" && projectPrivatePackage == "" && isBranchPrimary == false {
206+
logger.PrintIfVerbose("No tags or branch to update. Skipping project update.")
190207
return projectID, nil
191208
}
192209
if projectPrivatePackage != "" {

internal/services/projects_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func Test_createProject(t *testing.T) {
166166
ttt.args.applicationID,
167167
ttt.args.projectGroups,
168168
ttt.args.projectPrivatePackage,
169-
ttt.args.featureFlagsWrapper)
169+
ttt.args.featureFlagsWrapper, false, "")
170170
if (err != nil) != ttt.wantErr {
171171
t.Errorf("createProject() error = %v, wantErr %v", err, ttt.wantErr)
172172
return
@@ -240,7 +240,7 @@ func Test_updateProject(t *testing.T) {
240240
ttt := tt
241241
t.Run(tt.name, func(t *testing.T) {
242242
got, err := updateProject(ttt.args.project, ttt.args.projectsWrapper,
243-
ttt.args.projectTags, ttt.args.projectPrivatePackage)
243+
ttt.args.projectTags, ttt.args.projectPrivatePackage, false, "")
244244
if (err != nil) != ttt.wantErr {
245245
t.Errorf("updateProject() error = %v, wantErr %v", err, ttt.wantErr)
246246
return

test/integration/scan_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,12 @@ func TestIncrementalScan(t *testing.T) {
654654
executeScanAssertions(t, projectIDInc, scanIDInc, map[string]string{})
655655
}
656656

657+
func TestBranchPrimaryFlag(t *testing.T) {
658+
projectName := getProjectNameForScanTests()
659+
scanID, projectID := createScanWithPrimaryBranchFlag(t, Dir, projectName, map[string]string{})
660+
executeScanAssertions(t, projectID, scanID, map[string]string{})
661+
}
662+
657663
// Start a scan guaranteed to take considerable time, cancel it and assert the status
658664
func TestCancelScan(t *testing.T) {
659665
scanID, _ := createScanSastNoWait(t, SlowRepo, map[string]string{})
@@ -969,6 +975,10 @@ func createScanScaWithResolver(
969975
)
970976
}
971977

978+
func createScanWithPrimaryBranchFlag(t *testing.T, source string, name string, tags map[string]string) (string, string) {
979+
return executeCreateScan(t, append(getCreateArgsWithName(source, tags, name, "sast,sca,iac-security"), "--branch-primary"))
980+
}
981+
972982
func createScanIncremental(t *testing.T, source string, name string, tags map[string]string) (string, string) {
973983
return executeCreateScan(t, append(getCreateArgsWithName(source, tags, name, "sast,sca,iac-security"), "--sast-incremental"))
974984
}

0 commit comments

Comments
 (0)