Refactor http client

This commit is contained in:
Ben Sarmiento
2024-04-27 23:21:49 +02:00
parent 6983f59483
commit dd65d07037

View File

@@ -90,10 +90,20 @@ func NewHTTPClient(
maxConnections := cfg.GetNumOfWorkers()
if maxConnections > 32 {
maxConnections = 32
maxConnections = 32 // real-debrid has a limit of 32 connections per server/host
}
client.client.Transport = &http.Transport{
ResponseHeaderTimeout: time.Duration(timeoutSecs) * time.Second,
MaxIdleConns: 0,
MaxConnsPerHost: maxConnections,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.Dial(network, address)
},
}
if cfg.ShouldForceIPv6() {
// fetch IPv6 hosts
ipv6List, err := hosts.FetchHosts(hosts.IPV6)
if err != nil {
log.Warnf("Failed to fetch IPv6 hosts: %v", err)
@@ -103,41 +113,43 @@ func NewHTTPClient(
log.Debugf("Fetched %d IPv6 hosts", len(ipv6List))
}
client.client.Transport = &http.Transport{
ResponseHeaderTimeout: time.Duration(timeoutSecs) * time.Second,
MaxIdleConns: 0,
MaxConnsPerHost: maxConnections,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
if ipv6Address, ok := client.ipv6.Get(address); ok {
// replace the default dialer with a custom one that resolves hostnames to IPv6 addresses
client.client.Transport.(*http.Transport).DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
// if address is already cached, use it
if ipv6Address, ok := client.ipv6.Get(address); ok {
return dialer.Dial(network, ipv6Address)
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
for _, ip := range ips {
if ip.IP.To4() == nil { // IPv6 address found
ipv6Address := net.JoinHostPort(ip.IP.String(), port)
client.ipv6.Set(address, ipv6Address)
return dialer.Dial(network, ipv6Address)
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
// no IPv6 address found, use the original address
log.Warnf("No IPv6 address found for host %s", host)
for _, ip := range ips {
if ip.IP.To4() != nil { // IPv4 address found
ipV4Address := net.JoinHostPort(ip.IP.String(), port)
client.ipv6.Set(address, ipV4Address)
return dialer.Dial(network, ipV4Address)
}
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
for _, ip := range ips {
if ip.IP.To4() == nil { // IPv6 address found
ipv6Address := net.JoinHostPort(ip.IP.String(), port)
client.ipv6.Set(address, ipv6Address)
return dialer.Dial(network, ipv6Address)
}
}
return dialer.Dial(network, address)
},
}
} else {
client.client.Transport = &http.Transport{
ResponseHeaderTimeout: time.Duration(timeoutSecs) * time.Second,
MaxIdleConns: 0,
MaxConnsPerHost: maxConnections,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.Dial(network, address)
},
}
return dialer.Dial(network, address)
}
}
return &client
@@ -147,11 +159,13 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
if r.bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+r.bearerToken)
}
// check if Range header is set
reqHasRangeHeader := req.Header.Get("Range") != "" && !strings.HasPrefix(req.Header.Get("Range"), "bytes=0-")
var resp *http.Response
var err error
attempt := 0
var origBody []byte
if req.Method == "POST" {
@@ -182,6 +196,7 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
if incr > 0 {
attempt += incr
if attempt > r.maxRetries {
r.log.Warnf("Request to %s failed after %d attempts", req.URL.String(), r.maxRetries)
break
}
time.Sleep(r.backoff(attempt))