Refactor http client

This commit is contained in:
Ben Sarmiento
2024-04-27 23:21:49 +02:00
parent 6983f59483
commit dd65d07037

View File

@@ -90,10 +90,20 @@ func NewHTTPClient(
maxConnections := cfg.GetNumOfWorkers() maxConnections := cfg.GetNumOfWorkers()
if maxConnections > 32 { if maxConnections > 32 {
maxConnections = 32 maxConnections = 32 // real-debrid has a limit of 32 connections per server/host
}
client.client.Transport = &http.Transport{
ResponseHeaderTimeout: time.Duration(timeoutSecs) * time.Second,
MaxIdleConns: 0,
MaxConnsPerHost: maxConnections,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.Dial(network, address)
},
} }
if cfg.ShouldForceIPv6() { if cfg.ShouldForceIPv6() {
// fetch IPv6 hosts
ipv6List, err := hosts.FetchHosts(hosts.IPV6) ipv6List, err := hosts.FetchHosts(hosts.IPV6)
if err != nil { if err != nil {
log.Warnf("Failed to fetch IPv6 hosts: %v", err) log.Warnf("Failed to fetch IPv6 hosts: %v", err)
@@ -103,14 +113,13 @@ func NewHTTPClient(
log.Debugf("Fetched %d IPv6 hosts", len(ipv6List)) log.Debugf("Fetched %d IPv6 hosts", len(ipv6List))
} }
client.client.Transport = &http.Transport{ // replace the default dialer with a custom one that resolves hostnames to IPv6 addresses
ResponseHeaderTimeout: time.Duration(timeoutSecs) * time.Second, client.client.Transport.(*http.Transport).DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
MaxIdleConns: 0, // if address is already cached, use it
MaxConnsPerHost: maxConnections,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
if ipv6Address, ok := client.ipv6.Get(address); ok { if ipv6Address, ok := client.ipv6.Get(address); ok {
return dialer.Dial(network, ipv6Address) return dialer.Dial(network, ipv6Address)
} }
host, port, err := net.SplitHostPort(address) host, port, err := net.SplitHostPort(address)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -119,6 +128,7 @@ func NewHTTPClient(
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, ip := range ips { for _, ip := range ips {
if ip.IP.To4() == nil { // IPv6 address found if ip.IP.To4() == nil { // IPv6 address found
ipv6Address := net.JoinHostPort(ip.IP.String(), port) ipv6Address := net.JoinHostPort(ip.IP.String(), port)
@@ -126,18 +136,20 @@ func NewHTTPClient(
return dialer.Dial(network, ipv6Address) return dialer.Dial(network, ipv6Address)
} }
} }
return dialer.Dial(network, address)
}, // no IPv6 address found, use the original address
log.Warnf("No IPv6 address found for host %s", host)
for _, ip := range ips {
if ip.IP.To4() != nil { // IPv4 address found
ipV4Address := net.JoinHostPort(ip.IP.String(), port)
client.ipv6.Set(address, ipV4Address)
return dialer.Dial(network, ipV4Address)
} }
} else {
client.client.Transport = &http.Transport{
ResponseHeaderTimeout: time.Duration(timeoutSecs) * time.Second,
MaxIdleConns: 0,
MaxConnsPerHost: maxConnections,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.Dial(network, address)
},
} }
return dialer.Dial(network, address)
}
} }
return &client return &client
@@ -147,11 +159,13 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
if r.bearerToken != "" { if r.bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+r.bearerToken) req.Header.Set("Authorization", "Bearer "+r.bearerToken)
} }
// check if Range header is set // check if Range header is set
reqHasRangeHeader := req.Header.Get("Range") != "" && !strings.HasPrefix(req.Header.Get("Range"), "bytes=0-") reqHasRangeHeader := req.Header.Get("Range") != "" && !strings.HasPrefix(req.Header.Get("Range"), "bytes=0-")
var resp *http.Response var resp *http.Response
var err error var err error
attempt := 0 attempt := 0
var origBody []byte var origBody []byte
if req.Method == "POST" { if req.Method == "POST" {
@@ -182,6 +196,7 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
if incr > 0 { if incr > 0 {
attempt += incr attempt += incr
if attempt > r.maxRetries { if attempt > r.maxRetries {
r.log.Warnf("Request to %s failed after %d attempts", req.URL.String(), r.maxRetries)
break break
} }
time.Sleep(r.backoff(attempt)) time.Sleep(r.backoff(attempt))