Skip to content

Commit c047a98

Browse files
authored
Merge pull request #2 from soulteary/feat/smart-scan
fear: smart config scanner
2 parents 326287a + 541b993 commit c047a98

File tree

7 files changed

+880
-38
lines changed

7 files changed

+880
-38
lines changed

internal/define/define.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,13 @@ type YAMLOutput struct {
3232
Default map[string]string `yaml:"default,omitempty"`
3333
Groups map[string]GroupConfig `yaml:",inline"`
3434
}
35+
36+
var ExcludePatterns = []string{
37+
"known_hosts",
38+
"authorized_keys",
39+
"*.pub",
40+
"id_*",
41+
"*.key",
42+
"*.pem",
43+
"*.ppk",
44+
}

internal/fn/fn.go

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,33 +99,19 @@ func DetectStringType(input string) string {
9999
}
100100

101101
func GetPathContent(src string) ([]byte, error) {
102-
srcInfo, err := os.Stat(src)
102+
configFiles, err := ReadSSHConfigs(src)
103103
if err != nil {
104-
return nil, fmt.Errorf("can not get source info: %v", err)
104+
return nil, err
105+
}
106+
if len(configFiles.Configs) == 0 {
107+
return nil, fmt.Errorf("no valid SSH config found in %s", src)
105108
}
106109

107110
var content []byte
108-
109-
if srcInfo.IsDir() {
110-
files, err := os.ReadDir(src)
111-
if err != nil {
112-
return nil, fmt.Errorf("can not read source directory: %v", err)
113-
}
114-
115-
for _, file := range files {
116-
if !file.IsDir() {
117-
filePath := filepath.Join(src, file.Name())
118-
fileContent, err := os.ReadFile(filePath)
119-
if err != nil {
120-
return nil, fmt.Errorf("can not read file %s: %v", filePath, err)
121-
}
122-
content = append(content, fileContent...)
123-
}
124-
}
125-
} else {
126-
content, err = os.ReadFile(src)
127-
if err != nil {
128-
return nil, fmt.Errorf("can not read source file: %v", err)
111+
for filePath := range configFiles.Configs {
112+
fileContent, err := os.ReadFile(filePath)
113+
if err == nil {
114+
content = append(content, fileContent...)
129115
}
130116
}
131117
return content, nil

internal/fn/fn_test.go

Lines changed: 114 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"testing"
1313

1414
Define "github.com/soulteary/ssh-config/internal/define"
15+
"github.com/soulteary/ssh-config/internal/fn"
1516
Fn "github.com/soulteary/ssh-config/internal/fn"
1617
)
1718

@@ -478,8 +479,8 @@ func TestGetPathContent(t *testing.T) {
478479

479480
file1 := filepath.Join(multiDir, "file1.txt")
480481
file2 := filepath.Join(multiDir, "file2.txt")
481-
content1 := []byte("Content of file 1")
482-
content2 := []byte("Content of file 2")
482+
content1 := []byte("Host test1")
483+
content2 := []byte("Host test2")
483484

484485
err = os.WriteFile(file1, content1, 0644)
485486
if err != nil {
@@ -532,7 +533,7 @@ func TestGetPathContent(t *testing.T) {
532533

533534
_, err = Fn.GetPathContent(dirWithUnreadableFile)
534535
if err == nil {
535-
t.Error("Expected error for directory with unreadable file, got nil")
536+
t.Error("Expected error for no valid SSH config found in, got nil", err)
536537
}
537538

538539
unreadableFile2 := filepath.Join(tempDir, "unreadable_single.txt")
@@ -550,8 +551,31 @@ func TestGetPathContent(t *testing.T) {
550551
_, err = Fn.GetPathContent(unreadableFile2)
551552
if err == nil {
552553
t.Error("Expected error for unreadable single file, got nil")
553-
} else if !strings.Contains(err.Error(), "can not read source file") {
554-
t.Errorf("Expected error message to contain 'can not read source file', got: %v", err)
554+
} else if !strings.Contains(err.Error(), "no valid SSH config found in") {
555+
t.Errorf("Expected error message to contain 'no valid SSH config found in', got: %v", err)
556+
}
557+
558+
dirWithCorruptFile := filepath.Join(tempDir, "dir_with_corrupt")
559+
err = os.Mkdir(dirWithCorruptFile, 0755)
560+
if err != nil {
561+
t.Fatalf("Failed to create dir_with_corrupt: %v", err)
562+
}
563+
564+
normalFile := filepath.Join(dirWithCorruptFile, "normal.txt")
565+
err = os.WriteFile(normalFile, []byte("Normal content"), 0644)
566+
if err != nil {
567+
t.Fatalf("Failed to create normal file: %v", err)
568+
}
569+
570+
corruptFile := filepath.Join(dirWithCorruptFile, "corrupt.txt")
571+
err = os.Symlink("/nonexistent/file", corruptFile)
572+
if err != nil {
573+
t.Fatalf("Failed to create corrupt file: %v", err)
574+
}
575+
576+
_, err = Fn.GetPathContent(dirWithCorruptFile)
577+
if err == nil {
578+
t.Fatalf("Expected error for directory with corrupt file, got nil")
555579
}
556580
}
557581

@@ -752,3 +776,88 @@ func TestTidyLastEmptyLines(t *testing.T) {
752776
})
753777
}
754778
}
779+
780+
//
781+
782+
// SSHConfig 模拟原始结构
783+
type SSHConfig struct {
784+
Configs map[string]interface{}
785+
}
786+
787+
// 创建测试用的配置文件目录
788+
func createTestConfigDir(t *testing.T) (string, error) {
789+
tmpDir := t.TempDir()
790+
791+
// 创建测试文件1
792+
test1Path := filepath.Join(tmpDir, "test1.txt")
793+
err := os.WriteFile(test1Path, []byte("Host abc"), 0644)
794+
if err != nil {
795+
return "", err
796+
}
797+
798+
// 创建测试文件2
799+
test2Path := filepath.Join(tmpDir, "test2.txt")
800+
err = os.WriteFile(test2Path, []byte("Host def"), 0644)
801+
if err != nil {
802+
return "", err
803+
}
804+
805+
// 创建一个无权限的文件
806+
noPermFile := filepath.Join(tmpDir, "no_perm.txt")
807+
err = os.WriteFile(noPermFile, []byte("no permission"), 0644)
808+
if err != nil {
809+
return "", err
810+
}
811+
err = os.Chmod(noPermFile, 0000) // 移除所有权限
812+
if err != nil {
813+
return "", err
814+
}
815+
816+
return tmpDir, nil
817+
}
818+
819+
func TestGetPathContent2(t *testing.T) {
820+
// 测试场景1: 成功读取文件
821+
t.Run("Success case", func(t *testing.T) {
822+
tmpDir, err := createTestConfigDir(t)
823+
if err != nil {
824+
t.Fatalf("Failed to create test directory: %v", err)
825+
}
826+
827+
content, err := fn.GetPathContent(tmpDir)
828+
829+
// 验证无权限文件的内容没有被包含
830+
if strings.Contains(string(content), "no permission") {
831+
t.Error("Content should not contain 'no permission' as the file is not readable")
832+
}
833+
})
834+
835+
// 测试场景2: 配置目录不存在
836+
t.Run("Non-existent directory", func(t *testing.T) {
837+
content, err := fn.GetPathContent("non_existent_dir")
838+
if err == nil {
839+
t.Error("Expected an error, got nil")
840+
}
841+
if content != nil {
842+
t.Error("Expected nil content")
843+
}
844+
})
845+
846+
// 测试场景3: 文件读取失败(权限问题)
847+
t.Run("File read error due to permissions", func(t *testing.T) {
848+
tmpDir, err := createTestConfigDir(t)
849+
if err != nil {
850+
t.Fatalf("Failed to create test directory: %v", err)
851+
}
852+
853+
content, err := fn.GetPathContent(tmpDir)
854+
if err != nil {
855+
t.Errorf("Expected no error, got %v", err)
856+
}
857+
858+
// 验证无权限文件的内容没有被包含
859+
if strings.Contains(string(content), "no permission") {
860+
t.Error("Content should not contain 'no permission' as the file is not readable")
861+
}
862+
})
863+
}

internal/fn/scanner.go

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package fn
2+
3+
import (
4+
"bufio"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"strings"
9+
10+
"github.com/soulteary/ssh-config/internal/define"
11+
)
12+
13+
type ConfigFile struct {
14+
Path string
15+
Content []string
16+
Hosts map[string]map[string]string
17+
}
18+
19+
type SSHConfig struct {
20+
Configs map[string]*ConfigFile // key: 配置文件路径
21+
}
22+
23+
func IsExcluded(filename string) bool {
24+
filename = strings.ToLower(filename)
25+
26+
for _, pattern := range define.ExcludePatterns {
27+
if matched, _ := filepath.Match(pattern, filename); matched {
28+
return true
29+
}
30+
}
31+
32+
return false
33+
}
34+
35+
func IsConfigFile(path string) bool {
36+
// read file first few lines to determine if it's SSH config file format
37+
file, err := os.Open(path)
38+
if err != nil {
39+
return false
40+
}
41+
defer file.Close()
42+
43+
scanner := bufio.NewScanner(file)
44+
lineCount := 0
45+
validLines := 0
46+
47+
// check first 5 lines
48+
for scanner.Scan() && lineCount < 5 {
49+
line := strings.TrimSpace(scanner.Text())
50+
lineCount++
51+
52+
if line == "" || strings.HasPrefix(line, "#") {
53+
continue
54+
}
55+
56+
parts := strings.Fields(line)
57+
if len(parts) >= 2 {
58+
key := strings.ToLower(parts[0])
59+
switch key {
60+
case "host", "hostname", "user", "port", "identityfile", "proxycommand":
61+
validLines++
62+
}
63+
}
64+
}
65+
66+
return validLines > 0
67+
}
68+
69+
func ReadSSHConfigs(sshPath string) (*SSHConfig, error) {
70+
config := &SSHConfig{
71+
Configs: make(map[string]*ConfigFile),
72+
}
73+
74+
info, err := os.Stat(sshPath)
75+
if err != nil {
76+
return nil, fmt.Errorf("failed to get path object info: %v", err)
77+
}
78+
79+
if !info.IsDir() {
80+
configFile := ReadSingleConfig(sshPath)
81+
if configFile != nil {
82+
config.Configs[sshPath] = configFile
83+
}
84+
return config, nil
85+
}
86+
87+
err = filepath.Walk(sshPath, func(path string, info os.FileInfo, err error) error {
88+
if err != nil {
89+
return err
90+
}
91+
92+
if info.IsDir() {
93+
return nil
94+
}
95+
96+
if IsExcluded(info.Name()) {
97+
return nil
98+
}
99+
100+
if !IsConfigFile(path) {
101+
return nil
102+
}
103+
104+
configFile := ReadSingleConfig(path)
105+
if configFile != nil {
106+
config.Configs[path] = configFile
107+
}
108+
return nil
109+
})
110+
111+
if err != nil {
112+
return nil, fmt.Errorf("failed to walk directory: %v", err)
113+
}
114+
115+
return config, nil
116+
}
117+
118+
func ReadSingleConfig(path string) *ConfigFile {
119+
file, err := os.Open(path)
120+
if err != nil {
121+
return nil
122+
}
123+
defer file.Close()
124+
125+
config := &ConfigFile{
126+
Path: path,
127+
Hosts: make(map[string]map[string]string),
128+
}
129+
130+
scanner := bufio.NewScanner(file)
131+
var currentHost string
132+
var content []string
133+
134+
for scanner.Scan() {
135+
line := strings.TrimSpace(scanner.Text())
136+
content = append(content, line)
137+
138+
if line == "" || strings.HasPrefix(line, "#") {
139+
continue
140+
}
141+
142+
parts := strings.Fields(line)
143+
if len(parts) == 2 {
144+
key := strings.ToLower(parts[0])
145+
value := strings.Join(parts[1:], " ")
146+
147+
if key == "host" {
148+
currentHost = value
149+
config.Hosts[currentHost] = make(map[string]string)
150+
} else if currentHost != "" {
151+
config.Hosts[currentHost][key] = value
152+
}
153+
}
154+
}
155+
156+
if err := scanner.Err(); err != nil {
157+
return nil
158+
}
159+
160+
config.Content = content
161+
return config
162+
}
163+
164+
func (c *SSHConfig) GetHostConfig(host string) map[string]map[string]string {
165+
results := make(map[string]map[string]string)
166+
167+
for path, config := range c.Configs {
168+
if hostConfig, exists := config.Hosts[host]; exists {
169+
results[path] = hostConfig
170+
}
171+
}
172+
173+
return results
174+
}
175+
176+
func (c *SSHConfig) PrintConfigs() {
177+
for path, config := range c.Configs {
178+
fmt.Printf("\n=== 配置文件: %s ===\n", path)
179+
for host, hostConfig := range config.Hosts {
180+
fmt.Printf("\nHost %s:\n", host)
181+
for key, value := range hostConfig {
182+
fmt.Printf(" %s = %s\n", key, value)
183+
}
184+
}
185+
}
186+
}

0 commit comments

Comments
 (0)