diff --git a/internal/app.go b/internal/app.go index 228d278..651968f 100644 --- a/internal/app.go +++ b/internal/app.go @@ -12,7 +12,6 @@ import ( "github.com/debridmediamanager/zurg/internal/handlers" "github.com/debridmediamanager/zurg/internal/torrent" "github.com/debridmediamanager/zurg/internal/universal" - "github.com/debridmediamanager/zurg/pkg/hosts" "github.com/debridmediamanager/zurg/pkg/http" "github.com/debridmediamanager/zurg/pkg/logutil" "github.com/debridmediamanager/zurg/pkg/premium" @@ -36,7 +35,7 @@ func MainApp(configPath string) { os.Exit(1) } - apiClient := http.NewHTTPClient(config.GetToken(), config.GetRetriesUntilFailed(), config.GetRealDebridTimeout(), nil, config, log.Named("httpclient")) + apiClient := http.NewHTTPClient(config.GetToken(), config.GetRetriesUntilFailed(), config.GetRealDebridTimeout(), false, config, log.Named("httpclient")) rd := realdebrid.NewRealDebrid(apiClient, log.Named("realdebrid")) @@ -52,14 +51,7 @@ func MainApp(configPath string) { utils.EnsureDirExists("data") // Ensure the data directory exists torrentMgr := torrent.NewTorrentManager(config, rd, p, log.Named("manager")) - var ipv6List []string - if config.ShouldForceIPv6() { - ipv6List, err = hosts.FetchHosts(hosts.IPV6) - if err != nil { - zurglog.Warnf("Failed to fetch IPv6 hosts: %v", err) - } - } - downloadClient := http.NewHTTPClient(config.GetToken(), config.GetRetriesUntilFailed(), 0, ipv6List, config, log.Named("dlclient")) + downloadClient := http.NewHTTPClient(config.GetToken(), config.GetRetriesUntilFailed(), 0, true, config, log.Named("dlclient")) downloader := universal.NewDownloader(downloadClient) router := chi.NewRouter() diff --git a/pkg/http/client.go b/pkg/http/client.go index 5b986cf..f7d100e 100644 --- a/pkg/http/client.go +++ b/pkg/http/client.go @@ -13,6 +13,7 @@ import ( "time" "github.com/debridmediamanager/zurg/internal/config" + "github.com/debridmediamanager/zurg/pkg/hosts" "github.com/debridmediamanager/zurg/pkg/logutil" cmap "github.com/orcaman/concurrent-map/v2" @@ -23,15 +24,16 @@ const ( ) type HTTPClient struct { - client *http.Client - maxRetries int - backoff func(attempt int) time.Duration - getRetryIncr func(resp *http.Response, hasRangeHeader bool, err error) int - bearerToken string - restrictToHosts []string - cfg config.ConfigInterface - ipv6 cmap.ConcurrentMap[string, string] - log *logutil.Logger + client *http.Client + maxRetries int + backoff func(attempt int) time.Duration + getRetryIncr func(resp *http.Response, hasRangeHeader bool, err error) int + bearerToken string + ensureIPv6Host bool + cfg config.ConfigInterface + ipv6 cmap.ConcurrentMap[string, string] + ipv6Hosts []string + log *logutil.Logger } // { @@ -48,7 +50,7 @@ func (e *ErrorResponse) Error() string { return fmt.Sprintf("api response error: %s (code: %d)", e.Message, e.Code) } -func NewHTTPClient(token string, maxRetries int, timeoutSecs int, restrictToHosts []string, cfg config.ConfigInterface, log *logutil.Logger) *HTTPClient { +func NewHTTPClient(token string, maxRetries int, timeoutSecs int, ensureIPv6Host bool, cfg config.ConfigInterface, log *logutil.Logger) *HTTPClient { client := HTTPClient{ bearerToken: token, client: &http.Client{ @@ -102,41 +104,51 @@ func NewHTTPClient(token string, maxRetries int, timeoutSecs int, restrictToHost } return RATE_LIMIT_FACTOR }, - restrictToHosts: restrictToHosts, - cfg: cfg, - ipv6: cmap.New[string](), - log: log, + ensureIPv6Host: ensureIPv6Host, + cfg: cfg, + ipv6: cmap.New[string](), + log: log, } - if cfg.ShouldForceIPv6() { - dialer := &net.Dialer{} - dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { - if ipv6Address, ok := client.ipv6.Get(address); ok { + if !cfg.ShouldForceIPv6() { + return &client + } + + // set ipv6 hosts + ipv6List, err := hosts.FetchHosts(hosts.IPV6) + if err != nil { + log.Warnf("Failed to fetch IPv6 hosts: %v", err) + } + client.ipv6Hosts = ipv6List + + // set ipv6 transport + dialer := &net.Dialer{} + dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { + if ipv6Address, ok := client.ipv6.Get(address); ok { + return dialer.DialContext(ctx, network, ipv6Address) + } + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + for _, ip := range ips { + if ip.IP.To4() == nil { // IPv6 address found + ip6Host := ip.IP.String() + ipv6Address := net.JoinHostPort(ip6Host, port) + client.ipv6.Set(address, ipv6Address) return dialer.DialContext(ctx, network, ipv6Address) } - host, port, err := net.SplitHostPort(address) - if err != nil { - return nil, err - } - ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) - if err != nil { - return nil, err - } - for _, ip := range ips { - if ip.IP.To4() == nil { // IPv6 address found - ip6Host := ip.IP.String() - ipv6Address := net.JoinHostPort(ip6Host, port) - client.ipv6.Set(address, ipv6Address) - return dialer.DialContext(ctx, network, ipv6Address) - } - } - return dialer.DialContext(ctx, network, address) } - transport := &http.Transport{ - DialContext: dialContext, - } - client.client.Transport = transport + return dialer.DialContext(ctx, network, address) } + transport := &http.Transport{ + DialContext: dialContext, + } + client.client.Transport = transport return &client } @@ -187,11 +199,11 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) { } func (r *HTTPClient) replaceHostIfNeeded(req *http.Request) { - if !r.cfg.ShouldForceIPv6() { + if !r.ensureIPv6Host || !r.cfg.ShouldForceIPv6() { return } - if len(r.restrictToHosts) == 0 { - // replace .com with .cloud + // if no hosts are found, just replace .com with .cloud + if len(r.ipv6Hosts) == 0 { host := req.URL.Host if strings.HasSuffix(host, ".com") { newHost := strings.Replace(host, ".com", ".cloud", 1) @@ -201,21 +213,23 @@ func (r *HTTPClient) replaceHostIfNeeded(req *http.Request) { } return } + // if hosts are found, ensure the host is an IPv6 host host, port, err := net.SplitHostPort(req.URL.Host) if err != nil { host = req.URL.Host // Use the host without port port = "443" // Default HTTPS port } found := false - for _, h := range r.restrictToHosts { + for _, h := range r.ipv6Hosts { if h == host { found = true break } } + // if host is not an IPv6 host, replace it with a random IPv6 host if !found { r.log.Warnf("Host %s is not an IPv6 host (ensure you have preferred_hosts properly set in your config.yml, if unset, run `zurg network-test -t ipv6`)", host) - randomHost := r.restrictToHosts[rand.Intn(len(r.restrictToHosts))] + randomHost := r.ipv6Hosts[rand.Intn(len(r.ipv6Hosts))] newHost := net.JoinHostPort(randomHost, port) newURL := *req.URL newURL.Host = newHost