From bb043572d3ef59e11497e4ae1833ca6001d9c6e6 Mon Sep 17 00:00:00 2001 From: CHEN Li Date: Fri, 13 Mar 2026 11:36:00 +0800 Subject: [PATCH] enhancement for supporting local ssh config, so you can use alias in ssh:// --- cmd/gost/program.go | 2 +- cmd/gost/sshconfig.go | 193 +++++++++++++++++++++++++++++++ cmd/gost/sshconfig_test.go | 226 +++++++++++++++++++++++++++++++++++++ cmd/gost/version.go | 2 +- 4 files changed, 421 insertions(+), 2 deletions(-) create mode 100644 cmd/gost/sshconfig.go create mode 100644 cmd/gost/sshconfig_test.go diff --git a/cmd/gost/program.go b/cmd/gost/program.go index 7ce007e..6a9f9fe 100644 --- a/cmd/gost/program.go +++ b/cmd/gost/program.go @@ -36,7 +36,7 @@ func (p *program) Init(env svc.Environment) error { parser.Init(parser.Args{ CfgFile: cfgFile, Services: services, - Nodes: nodes, + Nodes: expandSSHNodes(nodes), Debug: debug, Trace: trace, ApiAddr: apiAddr, diff --git a/cmd/gost/sshconfig.go b/cmd/gost/sshconfig.go new file mode 100644 index 0000000..072fdb1 --- /dev/null +++ b/cmd/gost/sshconfig.go @@ -0,0 +1,193 @@ +package main + +import ( + "bufio" + "net" + "net/url" + "os" + "path/filepath" + "strings" +) + +// sshHostEntry holds the resolved settings for a single SSH host alias. +type sshHostEntry struct { + hostname string // HostName directive + port string // Port directive + user string // User directive + identityFile string // first IdentityFile directive +} + +// parseSSHConfigLine splits an SSH config line into (key, value). +// Handles both "Key Value" and "Key=Value" (with optional spaces around =). +func parseSSHConfigLine(line string) (key, value string, ok bool) { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + return "", "", false + } + // Find first whitespace or '=' + i := 0 + for i < len(line) && line[i] != ' ' && line[i] != '\t' && line[i] != '=' { + i++ + } + if i == len(line) { + return "", "", false + } + key = line[:i] + rest := strings.TrimLeft(line[i:], " \t=") + // Strip trailing inline comment (must be preceded by whitespace) + if idx := strings.Index(rest, " #"); idx >= 0 { + rest = strings.TrimSpace(rest[:idx]) + } + value = strings.Trim(rest, `"'`) + return key, value, value != "" +} + +// readSSHConfig parses ~/.ssh/config and returns a map of host alias/pattern +// to resolved entry. Only exact-match patterns (no wildcards) are indexed so +// they can be looked up directly; wildcard patterns are skipped for now. +func readSSHConfig() map[string]*sshHostEntry { + home, err := os.UserHomeDir() + if err != nil { + return nil + } + path := filepath.Join(home, ".ssh", "config") + f, err := os.Open(path) + if err != nil { + return nil + } + defer f.Close() + + entries := make(map[string]*sshHostEntry) + var current *sshHostEntry + var currentPatterns []string + + commit := func() { + if current == nil { + return + } + for _, pat := range currentPatterns { + // Skip wildcard patterns — they can't be used for exact lookup. + if strings.ContainsAny(pat, "*?") { + continue + } + if _, exists := entries[pat]; !exists { + entries[pat] = current + } + } + } + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + key, value, ok := parseSSHConfigLine(scanner.Text()) + if !ok { + continue + } + switch strings.ToLower(key) { + case "host": + commit() + currentPatterns = strings.Fields(value) + current = &sshHostEntry{} + case "hostname": + if current != nil && current.hostname == "" { + current.hostname = value + } + case "port": + if current != nil && current.port == "" { + current.port = value + } + case "user": + if current != nil && current.user == "" { + current.user = value + } + case "identityfile": + if current != nil && current.identityFile == "" { + if strings.HasPrefix(value, "~/") { + home, _ := os.UserHomeDir() + value = filepath.Join(home, value[2:]) + } + current.identityFile = value + } + } + } + commit() + + return entries +} + +// expandSSHNode rewrites a single node URL string using ~/.ssh/config when the +// scheme is "ssh" and the host matches a Host entry. Fields already present in +// the URL are never overridden. +func expandSSHNode(raw string, cfg map[string]*sshHostEntry) string { + if cfg == nil || !strings.HasPrefix(raw, "ssh://") { + return raw + } + + u, err := url.Parse(raw) + if err != nil || u.Scheme != "ssh" { + return raw + } + + host := u.Hostname() + port := u.Port() + + entry, ok := cfg[host] + if !ok { + return raw + } + + // Resolve hostname. + newHost := host + if entry.hostname != "" { + newHost = entry.hostname + } + + // Resolve port: only substitute when the URL carries no explicit port or + // carries the SSH default (22) and the config specifies a different one. + newPort := port + if newPort == "" || newPort == "22" { + if entry.port != "" && entry.port != "22" { + newPort = entry.port + } + } + + if newPort != "" && newPort != "22" { + u.Host = net.JoinHostPort(newHost, newPort) + } else { + u.Host = newHost + } + + // Apply user if not already present in the URL. + if entry.user != "" && (u.User == nil || u.User.Username() == "") { + var pw string + if u.User != nil { + pw, _ = u.User.Password() + } + if pw != "" { + u.User = url.UserPassword(entry.user, pw) + } else { + u.User = url.User(entry.user) + } + } + + // Apply identity file as the "key" query parameter when not already set. + if entry.identityFile != "" { + q := u.Query() + if q.Get("privateKeyFile") == "" { + q.Set("privateKeyFile", entry.identityFile) + u.RawQuery = q.Encode() + } + } + + return u.String() +} + +// expandSSHNodes applies expandSSHNode to every element of the node list and +// returns the result as a new slice. The original slice is not modified. +func expandSSHNodes(nodes []string) []string { + cfg := readSSHConfig() + result := make([]string, len(nodes)) + for i, n := range nodes { + result[i] = expandSSHNode(n, cfg) + } + return result +} diff --git a/cmd/gost/sshconfig_test.go b/cmd/gost/sshconfig_test.go new file mode 100644 index 0000000..ba78885 --- /dev/null +++ b/cmd/gost/sshconfig_test.go @@ -0,0 +1,226 @@ +package main + +import ( + "net/url" + "os" + "path/filepath" + "testing" +) + +func writeSSHConfig(t *testing.T, content string) (cleanup func()) { + t.Helper() + dir := t.TempDir() + sshDir := filepath.Join(dir, ".ssh") + if err := os.Mkdir(sshDir, 0700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(sshDir, "config"), []byte(content), 0600); err != nil { + t.Fatal(err) + } + old := os.Getenv("HOME") + os.Setenv("HOME", dir) + return func() { os.Setenv("HOME", old) } +} + +func TestParseSSHConfigLine(t *testing.T) { + cases := []struct { + line, wantKey, wantVal string + wantOk bool + }{ + {" HostName example.com", "HostName", "example.com", true}, + {"HostName=example.com", "HostName", "example.com", true}, + {"HostName = example.com", "HostName", "example.com", true}, + {"# comment", "", "", false}, + {"", "", "", false}, + {`IdentityFile "~/.ssh/id_rsa"`, "IdentityFile", "~/.ssh/id_rsa", true}, + {"Host myalias", "Host", "myalias", true}, + {"Host alias1 alias2", "Host", "alias1 alias2", true}, + } + for _, c := range cases { + k, v, ok := parseSSHConfigLine(c.line) + if ok != c.wantOk || k != c.wantKey || v != c.wantVal { + t.Errorf("parseSSHConfigLine(%q) = (%q, %q, %v), want (%q, %q, %v)", + c.line, k, v, ok, c.wantKey, c.wantVal, c.wantOk) + } + } +} + +func TestExpandSSHNode_NoMatch(t *testing.T) { + cfg := map[string]*sshHostEntry{ + "myalias": {hostname: "real.host.com", port: "2222", user: "alice"}, + } + if got := expandSSHNode("http://myalias:8080", cfg); got != "http://myalias:8080" { + t.Errorf("non-SSH URL was modified: %q", got) + } + if got := expandSSHNode("ssh://other:22", cfg); got != "ssh://other:22" { + t.Errorf("unknown SSH host was modified: %q", got) + } +} + +func TestExpandSSHNode_FullExpansion(t *testing.T) { + keyPath := "/home/alice/.ssh/id_ed25519" + cfg := map[string]*sshHostEntry{ + "myalias": { + hostname: "real.host.com", + port: "2222", + user: "alice", + identityFile: keyPath, + }, + } + + got := expandSSHNode("ssh://myalias", cfg) + u, err := url.Parse(got) + if err != nil { + t.Fatalf("parse result URL: %v", err) + } + if u.Hostname() != "real.host.com" { + t.Errorf("hostname: got %q, want %q", u.Hostname(), "real.host.com") + } + if u.Port() != "2222" { + t.Errorf("port: got %q, want %q", u.Port(), "2222") + } + if u.User.Username() != "alice" { + t.Errorf("user: got %q, want %q", u.User.Username(), "alice") + } + if q := u.Query().Get("privateKeyFile"); q != keyPath { + t.Errorf("privateKeyFile param: got %q, want %q", q, keyPath) + } +} + +func TestExpandSSHNode_PreserveExisting(t *testing.T) { + cfg := map[string]*sshHostEntry{ + "myalias": { + hostname: "real.host.com", + port: "2222", + user: "alice", + identityFile: "/home/alice/.ssh/id_ed25519", + }, + } + + got := expandSSHNode("ssh://bob@myalias:22?key=/tmp/other_key", cfg) + u, err := url.Parse(got) + if err != nil { + t.Fatalf("parse result URL: %v", err) + } + if u.Hostname() != "real.host.com" { + t.Errorf("hostname: got %q, want %q", u.Hostname(), "real.host.com") + } + // Port 22 is default; config has 2222 → should be substituted. + if u.Port() != "2222" { + t.Errorf("port: got %q, want %q", u.Port(), "2222") + } + // Existing user in URL must be preserved. + if u.User.Username() != "bob" { + t.Errorf("user: got %q, want %q", u.User.Username(), "bob") + } + // Existing TLS key param must be preserved unchanged. + if q := u.Query().Get("key"); q != "/tmp/other_key" { + t.Errorf("key param: got %q, want %q", q, "/tmp/other_key") + } + // privateKeyFile from config must be added (key is TLS, not SSH). + if q := u.Query().Get("privateKeyFile"); q != "/home/alice/.ssh/id_ed25519" { + t.Errorf("privateKeyFile param: got %q, want %q", q, "/home/alice/.ssh/id_ed25519") + } +} + +func TestExpandSSHNode_PrivateKeyFileInURL_Preserved(t *testing.T) { + cfg := map[string]*sshHostEntry{ + "srv": {hostname: "1.2.3.4", identityFile: "/home/user/.ssh/id_rsa"}, + } + got := expandSSHNode("ssh://srv?privateKeyFile=/custom/key", cfg) + u, _ := url.Parse(got) + // privateKeyFile already set → identityFile from config should NOT override. + if q := u.Query().Get("privateKeyFile"); q != "/custom/key" { + t.Errorf("privateKeyFile param: got %q, want %q", q, "/custom/key") + } +} + +func TestReadSSHConfig(t *testing.T) { + cleanup := writeSSHConfig(t, ` +# global defaults +Host * + ServerAliveInterval 60 + +Host bastion + HostName bastion.example.com + Port 2222 + User deploy + IdentityFile ~/.ssh/bastion_key + +Host dev + HostName 10.0.0.5 + User dev +`) + defer cleanup() + + cfg := readSSHConfig() + if cfg == nil { + t.Fatal("readSSHConfig returned nil") + } + + e, ok := cfg["bastion"] + if !ok { + t.Fatal("expected 'bastion' entry") + } + if e.hostname != "bastion.example.com" { + t.Errorf("hostname: got %q", e.hostname) + } + if e.port != "2222" { + t.Errorf("port: got %q", e.port) + } + if e.user != "deploy" { + t.Errorf("user: got %q", e.user) + } + if filepath.Base(e.identityFile) != "bastion_key" { + t.Errorf("identityFile: got %q", e.identityFile) + } + + dev, ok := cfg["dev"] + if !ok { + t.Fatal("expected 'dev' entry") + } + if dev.hostname != "10.0.0.5" { + t.Errorf("dev hostname: got %q", dev.hostname) + } + + // Wildcard Host * must not be indexed by exact key. + if _, ok := cfg["*"]; ok { + t.Error("wildcard Host * should not appear in entries map") + } +} + +func TestExpandSSHNodes_Integration(t *testing.T) { + cleanup := writeSSHConfig(t, ` +Host jump + HostName jump.corp.example.com + Port 2200 + User ops + IdentityFile ~/.ssh/jump_key +`) + defer cleanup() + + nodes := []string{ + "ssh://jump", + "http://proxy:8080", + "ssh://realhost:22", + } + got := expandSSHNodes(nodes) + + u0, _ := url.Parse(got[0]) + if u0.Hostname() != "jump.corp.example.com" { + t.Errorf("[0] hostname: got %q", u0.Hostname()) + } + if u0.Port() != "2200" { + t.Errorf("[0] port: got %q", u0.Port()) + } + if u0.User.Username() != "ops" { + t.Errorf("[0] user: got %q", u0.User.Username()) + } + if got[1] != "http://proxy:8080" { + t.Errorf("[1] non-SSH URL changed: %q", got[1]) + } + // ssh://realhost:22 — not in config, unchanged. + if got[2] != "ssh://realhost:22" { + t.Errorf("[2] unknown host changed: %q", got[2]) + } +} diff --git a/cmd/gost/version.go b/cmd/gost/version.go index b5a6251..2c633c8 100644 --- a/cmd/gost/version.go +++ b/cmd/gost/version.go @@ -1,5 +1,5 @@ package main var ( - version = "3.2.6" + version = "3.2.7" )