enhancement for supporting local ssh config, so you can use alias in ssh://<alias>

pull/849/head
CHEN Li 2026-03-13 11:36:00 +08:00
parent 340ba32ef0
commit bb043572d3
4 changed files with 421 additions and 2 deletions

View File

@ -36,7 +36,7 @@ func (p *program) Init(env svc.Environment) error {
parser.Init(parser.Args{
CfgFile: cfgFile,
Services: services,
Nodes: nodes,
Nodes: expandSSHNodes(nodes),
Debug: debug,
Trace: trace,
ApiAddr: apiAddr,

View File

@ -0,0 +1,193 @@
package main
import (
"bufio"
"net"
"net/url"
"os"
"path/filepath"
"strings"
)
// sshHostEntry holds the resolved settings for a single SSH host alias.
type sshHostEntry struct {
hostname string // HostName directive
port string // Port directive
user string // User directive
identityFile string // first IdentityFile directive
}
// parseSSHConfigLine splits an SSH config line into (key, value).
// Handles both "Key Value" and "Key=Value" (with optional spaces around =).
func parseSSHConfigLine(line string) (key, value string, ok bool) {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
return "", "", false
}
// Find first whitespace or '='
i := 0
for i < len(line) && line[i] != ' ' && line[i] != '\t' && line[i] != '=' {
i++
}
if i == len(line) {
return "", "", false
}
key = line[:i]
rest := strings.TrimLeft(line[i:], " \t=")
// Strip trailing inline comment (must be preceded by whitespace)
if idx := strings.Index(rest, " #"); idx >= 0 {
rest = strings.TrimSpace(rest[:idx])
}
value = strings.Trim(rest, `"'`)
return key, value, value != ""
}
// readSSHConfig parses ~/.ssh/config and returns a map of host alias/pattern
// to resolved entry. Only exact-match patterns (no wildcards) are indexed so
// they can be looked up directly; wildcard patterns are skipped for now.
func readSSHConfig() map[string]*sshHostEntry {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
path := filepath.Join(home, ".ssh", "config")
f, err := os.Open(path)
if err != nil {
return nil
}
defer f.Close()
entries := make(map[string]*sshHostEntry)
var current *sshHostEntry
var currentPatterns []string
commit := func() {
if current == nil {
return
}
for _, pat := range currentPatterns {
// Skip wildcard patterns — they can't be used for exact lookup.
if strings.ContainsAny(pat, "*?") {
continue
}
if _, exists := entries[pat]; !exists {
entries[pat] = current
}
}
}
scanner := bufio.NewScanner(f)
for scanner.Scan() {
key, value, ok := parseSSHConfigLine(scanner.Text())
if !ok {
continue
}
switch strings.ToLower(key) {
case "host":
commit()
currentPatterns = strings.Fields(value)
current = &sshHostEntry{}
case "hostname":
if current != nil && current.hostname == "" {
current.hostname = value
}
case "port":
if current != nil && current.port == "" {
current.port = value
}
case "user":
if current != nil && current.user == "" {
current.user = value
}
case "identityfile":
if current != nil && current.identityFile == "" {
if strings.HasPrefix(value, "~/") {
home, _ := os.UserHomeDir()
value = filepath.Join(home, value[2:])
}
current.identityFile = value
}
}
}
commit()
return entries
}
// expandSSHNode rewrites a single node URL string using ~/.ssh/config when the
// scheme is "ssh" and the host matches a Host entry. Fields already present in
// the URL are never overridden.
func expandSSHNode(raw string, cfg map[string]*sshHostEntry) string {
if cfg == nil || !strings.HasPrefix(raw, "ssh://") {
return raw
}
u, err := url.Parse(raw)
if err != nil || u.Scheme != "ssh" {
return raw
}
host := u.Hostname()
port := u.Port()
entry, ok := cfg[host]
if !ok {
return raw
}
// Resolve hostname.
newHost := host
if entry.hostname != "" {
newHost = entry.hostname
}
// Resolve port: only substitute when the URL carries no explicit port or
// carries the SSH default (22) and the config specifies a different one.
newPort := port
if newPort == "" || newPort == "22" {
if entry.port != "" && entry.port != "22" {
newPort = entry.port
}
}
if newPort != "" && newPort != "22" {
u.Host = net.JoinHostPort(newHost, newPort)
} else {
u.Host = newHost
}
// Apply user if not already present in the URL.
if entry.user != "" && (u.User == nil || u.User.Username() == "") {
var pw string
if u.User != nil {
pw, _ = u.User.Password()
}
if pw != "" {
u.User = url.UserPassword(entry.user, pw)
} else {
u.User = url.User(entry.user)
}
}
// Apply identity file as the "key" query parameter when not already set.
if entry.identityFile != "" {
q := u.Query()
if q.Get("privateKeyFile") == "" {
q.Set("privateKeyFile", entry.identityFile)
u.RawQuery = q.Encode()
}
}
return u.String()
}
// expandSSHNodes applies expandSSHNode to every element of the node list and
// returns the result as a new slice. The original slice is not modified.
func expandSSHNodes(nodes []string) []string {
cfg := readSSHConfig()
result := make([]string, len(nodes))
for i, n := range nodes {
result[i] = expandSSHNode(n, cfg)
}
return result
}

View File

@ -0,0 +1,226 @@
package main
import (
"net/url"
"os"
"path/filepath"
"testing"
)
func writeSSHConfig(t *testing.T, content string) (cleanup func()) {
t.Helper()
dir := t.TempDir()
sshDir := filepath.Join(dir, ".ssh")
if err := os.Mkdir(sshDir, 0700); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(sshDir, "config"), []byte(content), 0600); err != nil {
t.Fatal(err)
}
old := os.Getenv("HOME")
os.Setenv("HOME", dir)
return func() { os.Setenv("HOME", old) }
}
func TestParseSSHConfigLine(t *testing.T) {
cases := []struct {
line, wantKey, wantVal string
wantOk bool
}{
{" HostName example.com", "HostName", "example.com", true},
{"HostName=example.com", "HostName", "example.com", true},
{"HostName = example.com", "HostName", "example.com", true},
{"# comment", "", "", false},
{"", "", "", false},
{`IdentityFile "~/.ssh/id_rsa"`, "IdentityFile", "~/.ssh/id_rsa", true},
{"Host myalias", "Host", "myalias", true},
{"Host alias1 alias2", "Host", "alias1 alias2", true},
}
for _, c := range cases {
k, v, ok := parseSSHConfigLine(c.line)
if ok != c.wantOk || k != c.wantKey || v != c.wantVal {
t.Errorf("parseSSHConfigLine(%q) = (%q, %q, %v), want (%q, %q, %v)",
c.line, k, v, ok, c.wantKey, c.wantVal, c.wantOk)
}
}
}
func TestExpandSSHNode_NoMatch(t *testing.T) {
cfg := map[string]*sshHostEntry{
"myalias": {hostname: "real.host.com", port: "2222", user: "alice"},
}
if got := expandSSHNode("http://myalias:8080", cfg); got != "http://myalias:8080" {
t.Errorf("non-SSH URL was modified: %q", got)
}
if got := expandSSHNode("ssh://other:22", cfg); got != "ssh://other:22" {
t.Errorf("unknown SSH host was modified: %q", got)
}
}
func TestExpandSSHNode_FullExpansion(t *testing.T) {
keyPath := "/home/alice/.ssh/id_ed25519"
cfg := map[string]*sshHostEntry{
"myalias": {
hostname: "real.host.com",
port: "2222",
user: "alice",
identityFile: keyPath,
},
}
got := expandSSHNode("ssh://myalias", cfg)
u, err := url.Parse(got)
if err != nil {
t.Fatalf("parse result URL: %v", err)
}
if u.Hostname() != "real.host.com" {
t.Errorf("hostname: got %q, want %q", u.Hostname(), "real.host.com")
}
if u.Port() != "2222" {
t.Errorf("port: got %q, want %q", u.Port(), "2222")
}
if u.User.Username() != "alice" {
t.Errorf("user: got %q, want %q", u.User.Username(), "alice")
}
if q := u.Query().Get("privateKeyFile"); q != keyPath {
t.Errorf("privateKeyFile param: got %q, want %q", q, keyPath)
}
}
func TestExpandSSHNode_PreserveExisting(t *testing.T) {
cfg := map[string]*sshHostEntry{
"myalias": {
hostname: "real.host.com",
port: "2222",
user: "alice",
identityFile: "/home/alice/.ssh/id_ed25519",
},
}
got := expandSSHNode("ssh://bob@myalias:22?key=/tmp/other_key", cfg)
u, err := url.Parse(got)
if err != nil {
t.Fatalf("parse result URL: %v", err)
}
if u.Hostname() != "real.host.com" {
t.Errorf("hostname: got %q, want %q", u.Hostname(), "real.host.com")
}
// Port 22 is default; config has 2222 → should be substituted.
if u.Port() != "2222" {
t.Errorf("port: got %q, want %q", u.Port(), "2222")
}
// Existing user in URL must be preserved.
if u.User.Username() != "bob" {
t.Errorf("user: got %q, want %q", u.User.Username(), "bob")
}
// Existing TLS key param must be preserved unchanged.
if q := u.Query().Get("key"); q != "/tmp/other_key" {
t.Errorf("key param: got %q, want %q", q, "/tmp/other_key")
}
// privateKeyFile from config must be added (key is TLS, not SSH).
if q := u.Query().Get("privateKeyFile"); q != "/home/alice/.ssh/id_ed25519" {
t.Errorf("privateKeyFile param: got %q, want %q", q, "/home/alice/.ssh/id_ed25519")
}
}
func TestExpandSSHNode_PrivateKeyFileInURL_Preserved(t *testing.T) {
cfg := map[string]*sshHostEntry{
"srv": {hostname: "1.2.3.4", identityFile: "/home/user/.ssh/id_rsa"},
}
got := expandSSHNode("ssh://srv?privateKeyFile=/custom/key", cfg)
u, _ := url.Parse(got)
// privateKeyFile already set → identityFile from config should NOT override.
if q := u.Query().Get("privateKeyFile"); q != "/custom/key" {
t.Errorf("privateKeyFile param: got %q, want %q", q, "/custom/key")
}
}
func TestReadSSHConfig(t *testing.T) {
cleanup := writeSSHConfig(t, `
# global defaults
Host *
ServerAliveInterval 60
Host bastion
HostName bastion.example.com
Port 2222
User deploy
IdentityFile ~/.ssh/bastion_key
Host dev
HostName 10.0.0.5
User dev
`)
defer cleanup()
cfg := readSSHConfig()
if cfg == nil {
t.Fatal("readSSHConfig returned nil")
}
e, ok := cfg["bastion"]
if !ok {
t.Fatal("expected 'bastion' entry")
}
if e.hostname != "bastion.example.com" {
t.Errorf("hostname: got %q", e.hostname)
}
if e.port != "2222" {
t.Errorf("port: got %q", e.port)
}
if e.user != "deploy" {
t.Errorf("user: got %q", e.user)
}
if filepath.Base(e.identityFile) != "bastion_key" {
t.Errorf("identityFile: got %q", e.identityFile)
}
dev, ok := cfg["dev"]
if !ok {
t.Fatal("expected 'dev' entry")
}
if dev.hostname != "10.0.0.5" {
t.Errorf("dev hostname: got %q", dev.hostname)
}
// Wildcard Host * must not be indexed by exact key.
if _, ok := cfg["*"]; ok {
t.Error("wildcard Host * should not appear in entries map")
}
}
func TestExpandSSHNodes_Integration(t *testing.T) {
cleanup := writeSSHConfig(t, `
Host jump
HostName jump.corp.example.com
Port 2200
User ops
IdentityFile ~/.ssh/jump_key
`)
defer cleanup()
nodes := []string{
"ssh://jump",
"http://proxy:8080",
"ssh://realhost:22",
}
got := expandSSHNodes(nodes)
u0, _ := url.Parse(got[0])
if u0.Hostname() != "jump.corp.example.com" {
t.Errorf("[0] hostname: got %q", u0.Hostname())
}
if u0.Port() != "2200" {
t.Errorf("[0] port: got %q", u0.Port())
}
if u0.User.Username() != "ops" {
t.Errorf("[0] user: got %q", u0.User.Username())
}
if got[1] != "http://proxy:8080" {
t.Errorf("[1] non-SSH URL changed: %q", got[1])
}
// ssh://realhost:22 — not in config, unchanged.
if got[2] != "ssh://realhost:22" {
t.Errorf("[2] unknown host changed: %q", got[2])
}
}

View File

@ -1,5 +1,5 @@
package main
var (
version = "3.2.6"
version = "3.2.7"
)