Refactor http client
This commit is contained in:
@@ -90,10 +90,20 @@ func NewHTTPClient(
|
||||
|
||||
maxConnections := cfg.GetNumOfWorkers()
|
||||
if maxConnections > 32 {
|
||||
maxConnections = 32
|
||||
maxConnections = 32 // real-debrid has a limit of 32 connections per server/host
|
||||
}
|
||||
|
||||
client.client.Transport = &http.Transport{
|
||||
ResponseHeaderTimeout: time.Duration(timeoutSecs) * time.Second,
|
||||
MaxIdleConns: 0,
|
||||
MaxConnsPerHost: maxConnections,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return dialer.Dial(network, address)
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.ShouldForceIPv6() {
|
||||
// fetch IPv6 hosts
|
||||
ipv6List, err := hosts.FetchHosts(hosts.IPV6)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to fetch IPv6 hosts: %v", err)
|
||||
@@ -103,41 +113,43 @@ func NewHTTPClient(
|
||||
log.Debugf("Fetched %d IPv6 hosts", len(ipv6List))
|
||||
}
|
||||
|
||||
client.client.Transport = &http.Transport{
|
||||
ResponseHeaderTimeout: time.Duration(timeoutSecs) * time.Second,
|
||||
MaxIdleConns: 0,
|
||||
MaxConnsPerHost: maxConnections,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if ipv6Address, ok := client.ipv6.Get(address); ok {
|
||||
// replace the default dialer with a custom one that resolves hostnames to IPv6 addresses
|
||||
client.client.Transport.(*http.Transport).DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// if address is already cached, use it
|
||||
if ipv6Address, ok := client.ipv6.Get(address); ok {
|
||||
return dialer.Dial(network, ipv6Address)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.IP.To4() == nil { // IPv6 address found
|
||||
ipv6Address := net.JoinHostPort(ip.IP.String(), port)
|
||||
client.ipv6.Set(address, ipv6Address)
|
||||
return dialer.Dial(network, ipv6Address)
|
||||
}
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// no IPv6 address found, use the original address
|
||||
log.Warnf("No IPv6 address found for host %s", host)
|
||||
for _, ip := range ips {
|
||||
if ip.IP.To4() != nil { // IPv4 address found
|
||||
ipV4Address := net.JoinHostPort(ip.IP.String(), port)
|
||||
client.ipv6.Set(address, ipV4Address)
|
||||
return dialer.Dial(network, ipV4Address)
|
||||
}
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if ip.IP.To4() == nil { // IPv6 address found
|
||||
ipv6Address := net.JoinHostPort(ip.IP.String(), port)
|
||||
client.ipv6.Set(address, ipv6Address)
|
||||
return dialer.Dial(network, ipv6Address)
|
||||
}
|
||||
}
|
||||
return dialer.Dial(network, address)
|
||||
},
|
||||
}
|
||||
} else {
|
||||
client.client.Transport = &http.Transport{
|
||||
ResponseHeaderTimeout: time.Duration(timeoutSecs) * time.Second,
|
||||
MaxIdleConns: 0,
|
||||
MaxConnsPerHost: maxConnections,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return dialer.Dial(network, address)
|
||||
},
|
||||
}
|
||||
|
||||
return dialer.Dial(network, address)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return &client
|
||||
@@ -147,11 +159,13 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if r.bearerToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+r.bearerToken)
|
||||
}
|
||||
|
||||
// check if Range header is set
|
||||
reqHasRangeHeader := req.Header.Get("Range") != "" && !strings.HasPrefix(req.Header.Get("Range"), "bytes=0-")
|
||||
|
||||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
attempt := 0
|
||||
var origBody []byte
|
||||
if req.Method == "POST" {
|
||||
@@ -182,6 +196,7 @@ func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if incr > 0 {
|
||||
attempt += incr
|
||||
if attempt > r.maxRetries {
|
||||
r.log.Warnf("Request to %s failed after %d attempts", req.URL.String(), r.maxRetries)
|
||||
break
|
||||
}
|
||||
time.Sleep(r.backoff(attempt))
|
||||
|
||||
Reference in New Issue
Block a user