Replace host on every retry

This commit is contained in:
Ben Sarmiento
2024-01-26 18:33:15 +01:00
parent 8d4cdbbd1f
commit 1ff8cf2dfc

View File

@@ -10,6 +10,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
@@ -25,6 +26,7 @@ import (
type HTTPClient struct { type HTTPClient struct {
client *http.Client client *http.Client
maxRetries int maxRetries int
timeoutSecs int
backoff func(attempt int) time.Duration backoff func(attempt int) time.Duration
getRetryIncr func(resp *http.Response, reqHasRangeHeader bool, err error) int getRetryIncr func(resp *http.Response, reqHasRangeHeader bool, err error) int
bearerToken string bearerToken string
@@ -54,6 +56,7 @@ func NewHTTPClient(token string, maxRetries int, timeoutSecs int, ensureIPv6Host
bearerToken: token, bearerToken: token,
client: &http.Client{}, client: &http.Client{},
maxRetries: maxRetries, maxRetries: maxRetries,
timeoutSecs: timeoutSecs,
backoff: func(attempt int) time.Duration { backoff: func(attempt int) time.Duration {
maxDuration := 60 maxDuration := 60
backoff := int(math.Pow(2, float64(attempt))) backoff := int(math.Pow(2, float64(attempt)))
@@ -182,7 +185,6 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
if r.bearerToken != "" { if r.bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+r.bearerToken) req.Header.Set("Authorization", "Bearer "+r.bearerToken)
} }
r.replaceHostIfNeeded(req)
// check if Range header is set // check if Range header is set
reqHasRangeHeader := req.Header.Get("Range") != "" && req.Header.Get("Range") != "bytes=0-" reqHasRangeHeader := req.Header.Get("Range") != "" && req.Header.Get("Range") != "bytes=0-"
@@ -190,7 +192,21 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
var err error var err error
attempt := 0 attempt := 0
for { for {
r.replaceHostIfNeeded(req)
ctx, cancel := context.WithTimeout(req.Context(), time.Duration(r.timeoutSecs)*time.Second)
defer cancel()
req = req.WithContext(ctx)
resp, err = r.client.Do(req) resp, err = r.client.Do(req)
select {
case <-ctx.Done():
if ctx.Err() == context.DeadlineExceeded {
err = fmt.Errorf("request timed out after %d seconds", r.timeoutSecs)
}
default:
}
if resp != nil && (resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusPartialContent) { if resp != nil && (resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusPartialContent) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
if body != nil { if body != nil {
@@ -221,19 +237,18 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
} }
func (r *HTTPClient) replaceHostIfNeeded(req *http.Request) { func (r *HTTPClient) replaceHostIfNeeded(req *http.Request) {
if !r.ensureIPv6Host || !r.cfg.ShouldForceIPv6() { if !r.ensureIPv6Host && !r.cfg.ShouldForceIPv6() || !strings.HasSuffix(req.URL.Host, "real-debrid.com") {
return return
} }
// if no hosts are found, just replace .com with .cloud // get subdomain of req.URL.Host
if len(r.ipv6Hosts) == 0 { subdomain := strings.Split(req.URL.Host, ".")[0]
host := req.URL.Host // check if subdomain is numeric
if strings.HasSuffix(host, ".com") { _, err := strconv.Atoi(subdomain)
newHost := strings.Replace(host, ".com", ".cloud", 1) if err == nil {
req.Host = newHost // subdomain is numeric, replace it with .cloud
req.URL.Host = newHost req.URL.Host = strings.Replace(req.URL.Host, ".com", ".cloud", 1)
}
} }
// if hosts are found, ensure the host is an IPv6 host // check if host is in the list of IPv6 hosts
found := false found := false
for _, h := range r.ipv6Hosts { for _, h := range r.ipv6Hosts {
if h == req.URL.Host { if h == req.URL.Host {
@@ -241,12 +256,9 @@ func (r *HTTPClient) replaceHostIfNeeded(req *http.Request) {
break break
} }
} }
// if host is not an IPv6 host, replace it with a random IPv6 host
if !found { if !found {
r.log.Warnf("Host %s is not an IPv6 host, replacing with a random IPv6 host (ensure you have preferred_hosts properly set in your config.yml, if unset, run `zurg network-test -t ipv6`)", req.URL.Host) // random IPv6 host
newHost := r.ipv6Hosts[rand.Intn(len(r.ipv6Hosts))] req.URL.Host = r.ipv6Hosts[rand.Intn(len(r.ipv6Hosts))]
req.Host = newHost
req.URL.Host = newHost
} }
} }