Fix ipv6 stuffs 4

This commit is contained in:
Ben Sarmiento
2024-01-11 03:02:05 +01:00
parent 7ea90d0754
commit 9938c1134d
2 changed files with 60 additions and 54 deletions

View File

@@ -12,7 +12,6 @@ import (
"github.com/debridmediamanager/zurg/internal/handlers" "github.com/debridmediamanager/zurg/internal/handlers"
"github.com/debridmediamanager/zurg/internal/torrent" "github.com/debridmediamanager/zurg/internal/torrent"
"github.com/debridmediamanager/zurg/internal/universal" "github.com/debridmediamanager/zurg/internal/universal"
"github.com/debridmediamanager/zurg/pkg/hosts"
"github.com/debridmediamanager/zurg/pkg/http" "github.com/debridmediamanager/zurg/pkg/http"
"github.com/debridmediamanager/zurg/pkg/logutil" "github.com/debridmediamanager/zurg/pkg/logutil"
"github.com/debridmediamanager/zurg/pkg/premium" "github.com/debridmediamanager/zurg/pkg/premium"
@@ -36,7 +35,7 @@ func MainApp(configPath string) {
os.Exit(1) os.Exit(1)
} }
apiClient := http.NewHTTPClient(config.GetToken(), config.GetRetriesUntilFailed(), config.GetRealDebridTimeout(), nil, config, log.Named("httpclient")) apiClient := http.NewHTTPClient(config.GetToken(), config.GetRetriesUntilFailed(), config.GetRealDebridTimeout(), false, config, log.Named("httpclient"))
rd := realdebrid.NewRealDebrid(apiClient, log.Named("realdebrid")) rd := realdebrid.NewRealDebrid(apiClient, log.Named("realdebrid"))
@@ -52,14 +51,7 @@ func MainApp(configPath string) {
utils.EnsureDirExists("data") // Ensure the data directory exists utils.EnsureDirExists("data") // Ensure the data directory exists
torrentMgr := torrent.NewTorrentManager(config, rd, p, log.Named("manager")) torrentMgr := torrent.NewTorrentManager(config, rd, p, log.Named("manager"))
var ipv6List []string downloadClient := http.NewHTTPClient(config.GetToken(), config.GetRetriesUntilFailed(), 0, true, config, log.Named("dlclient"))
if config.ShouldForceIPv6() {
ipv6List, err = hosts.FetchHosts(hosts.IPV6)
if err != nil {
zurglog.Warnf("Failed to fetch IPv6 hosts: %v", err)
}
}
downloadClient := http.NewHTTPClient(config.GetToken(), config.GetRetriesUntilFailed(), 0, ipv6List, config, log.Named("dlclient"))
downloader := universal.NewDownloader(downloadClient) downloader := universal.NewDownloader(downloadClient)
router := chi.NewRouter() router := chi.NewRouter()

View File

@@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/debridmediamanager/zurg/internal/config" "github.com/debridmediamanager/zurg/internal/config"
"github.com/debridmediamanager/zurg/pkg/hosts"
"github.com/debridmediamanager/zurg/pkg/logutil" "github.com/debridmediamanager/zurg/pkg/logutil"
cmap "github.com/orcaman/concurrent-map/v2" cmap "github.com/orcaman/concurrent-map/v2"
@@ -28,9 +29,10 @@ type HTTPClient struct {
backoff func(attempt int) time.Duration backoff func(attempt int) time.Duration
getRetryIncr func(resp *http.Response, hasRangeHeader bool, err error) int getRetryIncr func(resp *http.Response, hasRangeHeader bool, err error) int
bearerToken string bearerToken string
restrictToHosts []string ensureIPv6Host bool
cfg config.ConfigInterface cfg config.ConfigInterface
ipv6 cmap.ConcurrentMap[string, string] ipv6 cmap.ConcurrentMap[string, string]
ipv6Hosts []string
log *logutil.Logger log *logutil.Logger
} }
@@ -48,7 +50,7 @@ func (e *ErrorResponse) Error() string {
return fmt.Sprintf("api response error: %s (code: %d)", e.Message, e.Code) return fmt.Sprintf("api response error: %s (code: %d)", e.Message, e.Code)
} }
func NewHTTPClient(token string, maxRetries int, timeoutSecs int, restrictToHosts []string, cfg config.ConfigInterface, log *logutil.Logger) *HTTPClient { func NewHTTPClient(token string, maxRetries int, timeoutSecs int, ensureIPv6Host bool, cfg config.ConfigInterface, log *logutil.Logger) *HTTPClient {
client := HTTPClient{ client := HTTPClient{
bearerToken: token, bearerToken: token,
client: &http.Client{ client: &http.Client{
@@ -102,13 +104,24 @@ func NewHTTPClient(token string, maxRetries int, timeoutSecs int, restrictToHost
} }
return RATE_LIMIT_FACTOR return RATE_LIMIT_FACTOR
}, },
restrictToHosts: restrictToHosts, ensureIPv6Host: ensureIPv6Host,
cfg: cfg, cfg: cfg,
ipv6: cmap.New[string](), ipv6: cmap.New[string](),
log: log, log: log,
} }
if cfg.ShouldForceIPv6() { if !cfg.ShouldForceIPv6() {
return &client
}
// set ipv6 hosts
ipv6List, err := hosts.FetchHosts(hosts.IPV6)
if err != nil {
log.Warnf("Failed to fetch IPv6 hosts: %v", err)
}
client.ipv6Hosts = ipv6List
// set ipv6 transport
dialer := &net.Dialer{} dialer := &net.Dialer{}
dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
if ipv6Address, ok := client.ipv6.Get(address); ok { if ipv6Address, ok := client.ipv6.Get(address); ok {
@@ -136,7 +149,6 @@ func NewHTTPClient(token string, maxRetries int, timeoutSecs int, restrictToHost
DialContext: dialContext, DialContext: dialContext,
} }
client.client.Transport = transport client.client.Transport = transport
}
return &client return &client
} }
@@ -187,11 +199,11 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
} }
func (r *HTTPClient) replaceHostIfNeeded(req *http.Request) { func (r *HTTPClient) replaceHostIfNeeded(req *http.Request) {
if !r.cfg.ShouldForceIPv6() { if !r.ensureIPv6Host || !r.cfg.ShouldForceIPv6() {
return return
} }
if len(r.restrictToHosts) == 0 { // if no hosts are found, just replace .com with .cloud
// replace .com with .cloud if len(r.ipv6Hosts) == 0 {
host := req.URL.Host host := req.URL.Host
if strings.HasSuffix(host, ".com") { if strings.HasSuffix(host, ".com") {
newHost := strings.Replace(host, ".com", ".cloud", 1) newHost := strings.Replace(host, ".com", ".cloud", 1)
@@ -201,21 +213,23 @@ func (r *HTTPClient) replaceHostIfNeeded(req *http.Request) {
} }
return return
} }
// if hosts are found, ensure the host is an IPv6 host
host, port, err := net.SplitHostPort(req.URL.Host) host, port, err := net.SplitHostPort(req.URL.Host)
if err != nil { if err != nil {
host = req.URL.Host // Use the host without port host = req.URL.Host // Use the host without port
port = "443" // Default HTTPS port port = "443" // Default HTTPS port
} }
found := false found := false
for _, h := range r.restrictToHosts { for _, h := range r.ipv6Hosts {
if h == host { if h == host {
found = true found = true
break break
} }
} }
// if host is not an IPv6 host, replace it with a random IPv6 host
if !found { if !found {
r.log.Warnf("Host %s is not an IPv6 host (ensure you have preferred_hosts properly set in your config.yml, if unset, run `zurg network-test -t ipv6`)", host) r.log.Warnf("Host %s is not an IPv6 host (ensure you have preferred_hosts properly set in your config.yml, if unset, run `zurg network-test -t ipv6`)", host)
randomHost := r.restrictToHosts[rand.Intn(len(r.restrictToHosts))] randomHost := r.ipv6Hosts[rand.Intn(len(r.ipv6Hosts))]
newHost := net.JoinHostPort(randomHost, port) newHost := net.JoinHostPort(randomHost, port)
newURL := *req.URL newURL := *req.URL
newURL.Host = newHost newURL.Host = newHost