diff --git a/backend/xray/api/base.go b/backend/xray/api/base.go index c469736..715c388 100644 --- a/backend/xray/api/base.go +++ b/backend/xray/api/base.go @@ -1,7 +1,11 @@ package api import ( + "context" "fmt" + "net" + "time" + "github.com/xtls/xray-core/app/proxyman/command" statsService "github.com/xtls/xray-core/app/stats/command" "google.golang.org/grpc" @@ -16,9 +20,26 @@ type XrayHandler struct { func NewXrayAPI(apiPort int) (*XrayHandler, error) { x := &XrayHandler{} + target := fmt.Sprintf("127.0.0.1:%v", apiPort) + dialer := &net.Dialer{ + Timeout: 5 * time.Second, + LocalAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1")}, + } var err error - x.GrpcClient, err = grpc.NewClient(fmt.Sprintf("127.0.0.1:%v", apiPort), grpc.WithTransportCredentials(insecure.NewCredentials())) + x.GrpcClient, err = grpc.NewClient( + target, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + conn, dialErr := dialer.DialContext(ctx, "tcp4", addr) + if dialErr == nil { + return conn, nil + } + + var fallback net.Dialer + return fallback.DialContext(ctx, "tcp", addr) + }), + ) if err != nil { return nil, err diff --git a/backend/xray/config.go b/backend/xray/config.go index 5d16914..0b21858 100644 --- a/backend/xray/config.go +++ b/backend/xray/config.go @@ -4,7 +4,9 @@ import ( "encoding/json" "fmt" "log" + "net" "slices" + "sort" "strings" "sync" @@ -405,6 +407,144 @@ func filterRules(rules []json.RawMessage, apiTag string) ([]json.RawMessage, err return filtered, nil } +var privateCIDRs = []string{ + "0.0.0.0/8", + "10.0.0.0/8", + "100.64.0.0/10", + "127.0.0.0/8", + "169.254.0.0/16", + "172.16.0.0/12", + "192.0.0.0/24", + "192.168.0.0/16", + "198.18.0.0/15", + "224.0.0.0/4", + "240.0.0.0/4", + "::/128", + "::1/128", + "fc00::/7", + "fe80::/10", +} + +func replaceGeoIPPrivate(values any) (any, bool) { + list, ok := values.([]any) + if !ok { + return values, false + } + + updated := make([]any, 0, len(list)+len(privateCIDRs)) + changed := false + for _, entry := range list { + s, strOK := entry.(string) + if strOK && strings.EqualFold(s, "geoip:private") { + for _, cidr := range privateCIDRs { + updated = append(updated, cidr) + } + changed = true + continue + } + updated = append(updated, entry) + } + + if !changed { + return values, false + } + + return updated, true +} + +func normalizeGeoIPPrivateRules(rules []json.RawMessage) ([]json.RawMessage, error) { + if rules == nil { + return []json.RawMessage{}, nil + } + + normalized := make([]json.RawMessage, 0, len(rules)) + for _, raw := range rules { + var obj map[string]any + if err := json.Unmarshal(raw, &obj); err != nil { + return nil, fmt.Errorf("invalid JSON in rule: %w", err) + } + + ruleChanged := false + if ip, ok := obj["ip"]; ok { + newIP, changed := replaceGeoIPPrivate(ip) + if changed { + obj["ip"] = newIP + ruleChanged = true + } + } + + if source, ok := obj["source"]; ok { + newSource, changed := replaceGeoIPPrivate(source) + if changed { + obj["source"] = newSource + ruleChanged = true + } + } + + if !ruleChanged { + normalized = append(normalized, raw) + continue + } + + rawBytes, err := json.Marshal(obj) + if err != nil { + return nil, fmt.Errorf("failed to marshal normalized rule: %w", err) + } + normalized = append(normalized, json.RawMessage(rawBytes)) + } + + return normalized, nil +} + +func apiRuleSources() []string { + seen := map[string]struct{}{ + "127.0.0.1": {}, + "::1": {}, + } + + ifaces, err := net.Interfaces() + if err != nil { + return []string{"127.0.0.1", "::1"} + } + + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue + } + + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + default: + continue + } + + if ip == nil || ip.IsUnspecified() { + continue + } + + seen[ip.String()] = struct{}{} + } + } + + sources := make([]string, 0, len(seen)) + for source := range seen { + sources = append(sources, source) + } + sort.Strings(sources) + + return sources +} + func (c *Config) ApplyAPI(apiPort int) (err error) { // Remove the existing inbound with the API_INBOUND tag for i, inbound := range c.InboundConfigs { @@ -425,7 +565,14 @@ func (c *Config) ApplyAPI(apiPort int) (err error) { } rules := c.RouterConfig.RuleList + rules, err = normalizeGeoIPPrivateRules(rules) + if err != nil { + return err + } c.RouterConfig.RuleList, err = filterRules(rules, apiTag) + if err != nil { + return err + } c.checkPolicy() @@ -442,7 +589,7 @@ func (c *Config) ApplyAPI(apiPort int) (err error) { rule := map[string]any{ "inboundTag": []string{"API_INBOUND"}, - "source": []string{"127.0.0.1"}, + "source": apiRuleSources(), "outboundTag": "API", "type": "field", } diff --git a/controller/rest/base.go b/controller/rest/base.go index 533a86f..5a4609b 100644 --- a/controller/rest/base.go +++ b/controller/rest/base.go @@ -34,13 +34,13 @@ func (s *Service) Start(w http.ResponseWriter, r *http.Request) { s.Disconnect() } - s.Connect(ip, keepAlive) - if err = s.StartBackend(ctx, backendType); err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } + s.Connect(ip, keepAlive) + common.SendProtoResponse(w, s.BaseInfoResponse()) }