Files
zurg/pkg/http/client.go
2024-08-21 23:27:22 +02:00

367 lines
10 KiB
Go

package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/debridmediamanager/zurg/pkg/logutil"
http_dialer "github.com/mwitkow/go-http-dialer"
"golang.org/x/net/proxy"
cmap "github.com/orcaman/concurrent-map/v2"
)
// HTTPClient wraps an *http.Client with bearer-token authentication, retries
// with back-off (see Do and shouldRetry), optional client-side rate limiting
// and, for real-debrid download servers, rewriting of hosts that are not on
// the configured reachable list.
type HTTPClient struct {
	// token is sent as "Authorization: Bearer <token>" on requests that carry no Authorization header.
	token string
	// client performs the actual requests; NewHTTPClient installs its Transport.
	client *http.Client
	// maxRetries caps retries for invalid download codes and for requests that failed without a response.
	maxRetries int
	// timeoutSecs feeds the dial, response-header and proxy-connect timeouts, and VerifyLink's deadline.
	timeoutSecs int
	// rateLimitSleepSecs is the base (in seconds) handed to backoff for API rate-limit retries.
	rateLimitSleepSecs int
	// backoff maps (attempt, base seconds) to the delay before the next attempt.
	backoff func(int, int) time.Duration
	// dnsCache maps a dialed "host:port" to its resolved "[ipv6]:port" when forceIPv6 is enabled.
	dnsCache cmap.ConcurrentMap[string, string]
	// hosts lists download hosts considered reachable; consulted by ensureReachableHost.
	hosts []string
	// rateLimiter, when non-nil, is waited on before every attempt.
	rateLimiter *RateLimiter
	// log receives debug and error messages.
	log *logutil.Logger
}
// ApiErrorResponse is the error reported for a failed (status >= 400) call to
// the real-debrid API. It is decoded from a JSON body of the form
// {"error": "...", "error_code": N}; when the body is not valid JSON, Message
// holds the raw body and Code the HTTP status code instead.
type ApiErrorResponse struct {
	Message string `json:"error"`      // API error text
	Code    int    `json:"error_code"` // API error code, or HTTP status as a fallback
}

// Error implements the error interface.
func (e *ApiErrorResponse) Error() string {
	return fmt.Sprint("api response error: ", e.Message, " (code: ", e.Code, ")")
}
// DownloadErrorResponse is the error reported for a failed (status >= 400)
// request to a download server, i.e. any host other than the API host.
type DownloadErrorResponse struct {
	Message string // value of the X-Error response header; may be empty
	Code    int    // HTTP status code
}

// Error implements the error interface.
func (e *DownloadErrorResponse) Error() string {
	return fmt.Sprint("download response error: ", e.Message, " (code: ", e.Code, ")")
}
// NewHTTPClient builds an HTTPClient backed by its own http.Transport.
//
// timeoutSecs bounds connection establishment (DNS resolution and TCP
// handshake, or the proxy connect) and the wait for response headers. The
// underlying http.Client itself has no overall timeout. When proxyURL is
// non-empty, connections go through that proxy (see proxyDialer for the
// supported schemes). When forceIPv6 is set, hosts are resolved locally and
// dialed via the first IPv6 address returned. If no IPv6 address exists, the
// dial fails with "no ipv6 address found". Resolutions are cached in dnsCache
// and are not evicted here.
//
// It returns nil, after logging, if proxyURL cannot be parsed or its scheme is
// unsupported.
func NewHTTPClient(
	token string,
	maxRetries int,
	timeoutSecs int,
	forceIPv6 bool,
	hosts []string,
	proxyURL string,
	rateLimiter *RateLimiter,
	log *logutil.Logger,
) *HTTPClient {
	timeout := time.Duration(timeoutSecs) * time.Second
	client := HTTPClient{
		token:              token,
		client:             &http.Client{},
		maxRetries:         maxRetries,
		timeoutSecs:        timeoutSecs,
		rateLimitSleepSecs: 1,
		backoff:            backoffFunc,
		dnsCache:           cmap.New[string](),
		hosts:              hosts,
		rateLimiter:        rateLimiter,
		log:                log,
	}

	var dialer proxy.Dialer = &net.Dialer{
		Timeout: timeout, // timeout for dns resolution, tcp handshake
	}
	if proxyURL != "" {
		parsedProxyURL, err := url.Parse(proxyURL)
		if err != nil {
			log.Errorf("Failed to parse proxy URL: %v", err)
			return nil
		}
		dialer, err = client.proxyDialer(parsedProxyURL)
		if err != nil {
			log.Errorf("Failed to create proxy dialer: %v", err)
			return nil
		}
	}

	// Prefer the context-aware dial, so that request cancellation and
	// deadlines also abort connection establishment. Dialers that only
	// implement Dial fall back to it and ignore the context.
	dialContext := func(_ context.Context, network, address string) (net.Conn, error) {
		return dialer.Dial(network, address)
	}
	if cd, ok := dialer.(proxy.ContextDialer); ok {
		dialContext = cd.DialContext
	}

	transport := &http.Transport{
		ResponseHeaderTimeout: timeout,
		MaxIdleConnsPerHost:   32,
		MaxConnsPerHost:       32,
		IdleConnTimeout:       90 * time.Second,
		DialContext:           dialContext,
	}
	if forceIPv6 {
		// replace the default dial with one that resolves hostnames to IPv6 addresses
		transport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
			// if address is already cached, use it
			if ipv6Address, ok := client.dnsCache.Get(address); ok {
				return dialContext(ctx, network, ipv6Address)
			}
			host, port, err := net.SplitHostPort(address)
			if err != nil {
				return nil, err
			}
			ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
			if err != nil {
				return nil, err
			}
			for _, ip := range ips {
				if ip.IP.To4() == nil { // IPv6 address found
					ipv6Address := net.JoinHostPort(ip.IP.String(), port)
					client.dnsCache.Set(address, ipv6Address)
					return dialContext(ctx, network, ipv6Address)
				}
			}
			return nil, fmt.Errorf("no ipv6 address found")
		}
	}
	client.client.Transport = transport
	return &client
}
// Do sends req, retrying as decided by shouldRetry, and returns the final
// response and error.
//
// Before the first attempt, Do does two things:
//   - If a token is configured and req has no Authorization header, it sets
//     "Authorization: Bearer <token>".
//   - For POST/PUT/PATCH requests with a body, it buffers the body so the
//     payload can be replayed on every attempt. The original body is closed,
//     and a read error aborts the call.
//
// Each attempt may rewrite the download host (when reachable hosts are
// configured) and waits on the rate limiter (when one is set).
//
// A response with status >= 400 is converted into an error:
// *ApiErrorResponse for api.real-debrid.com, *DownloadErrorResponse otherwise.
// In that case the returned resp is still non-nil, but its body has already
// been consumed. Retrying also stops once req's context is done.
func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
	if r.token != "" && req.Header.Get("Authorization") == "" {
		req.Header.Set("Authorization", "Bearer "+r.token)
	}

	// A request body can only be read once, so buffer it for replay.
	var origBody []byte
	if (req.Method == http.MethodPost || req.Method == http.MethodPut || req.Method == http.MethodPatch) && req.Body != nil {
		body, readErr := io.ReadAll(req.Body)
		req.Body.Close()
		if readErr != nil {
			return nil, fmt.Errorf("reading request body: %w", readErr)
		}
		origBody = body
	}

	var (
		resp *http.Response
		err  error
	)
	for attempt := 0; ; attempt++ {
		if origBody != nil {
			req.Body = io.NopCloser(bytes.NewReader(origBody))
		}
		// release the previous attempt's response before trying again
		if resp != nil && resp.Body != nil {
			resp.Body.Close()
		}
		if len(r.hosts) > 0 {
			r.ensureReachableHost(req)
		}
		if r.rateLimiter != nil {
			r.rateLimiter.Wait()
		}
		resp, err = r.client.Do(req)
		// http 4xx and 5xx errors
		if resp != nil && resp.StatusCode >= http.StatusBadRequest {
			err = newResponseError(req, resp)
		}
		// the caller gave up; further attempts could only fail
		if req.Context().Err() != nil {
			break
		}
		if !r.shouldRetry(req, resp, err, attempt, r.rateLimitSleepSecs) {
			break
		}
	}
	return resp, err
}

// newResponseError drains resp.Body and converts a status >= 400 response into
// the error Do reports for it.
func newResponseError(req *http.Request, resp *http.Response) error {
	body, _ := io.ReadAll(resp.Body) // best effort: the body only feeds the error message
	if req.Host != "api.real-debrid.com" {
		// download servers
		return &DownloadErrorResponse{
			Message: resp.Header.Get("X-Error"),
			Code:    resp.StatusCode,
		}
	}
	// api servers
	var errResp ApiErrorResponse
	if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil {
		errResp.Message += fmt.Sprintf(" (status code: %d)", resp.StatusCode)
	} else {
		// not a JSON error body: keep it verbatim and use the HTTP status as the code
		errResp.Message = string(body)
		errResp.Code = resp.StatusCode
	}
	return &errResp
}
// ensureReachableHost points req at a reachable download host when possible.
//
// Only hosts containing ".download.real-debrid." are considered. Among those,
// hosts whose first character is a lowercase letter are treated as CDN
// servers and left alone.
//
// If req.Host is already listed in r.hosts, nothing changes. Otherwise the
// same host with the sibling top-level domain (.com <-> .cloud) is tried. If
// that variant is listed, req.Host and req.URL.Host are switched to it. When
// neither is listed, the original host is kept and a debug message is logged.
func (r *HTTPClient) ensureReachableHost(req *http.Request) {
	// skip if not a download server
	if !strings.Contains(req.Host, ".download.real-debrid.") {
		return
	}
	// skip CDN servers
	if first := req.Host[0]; first >= 'a' && first <= 'z' {
		return
	}
	if r.CheckIfHostIsReachable(req.Host) {
		return
	}
	// Swap the TLD suffix. Only the trailing suffix is replaced, never an
	// earlier occurrence of ".com"/".cloud" inside the host name.
	var newHost string
	switch {
	case strings.HasSuffix(req.Host, ".com"):
		newHost = strings.TrimSuffix(req.Host, ".com") + ".cloud"
	case strings.HasSuffix(req.Host, ".cloud"):
		newHost = strings.TrimSuffix(req.Host, ".cloud") + ".com"
	}
	// newHost stays empty for other suffixes; never rewrite the request to an empty host
	if newHost != "" && r.CheckIfHostIsReachable(newHost) {
		req.Host = newHost
		req.URL.Host = newHost
		return
	}
	// just retain the original host if not in the list of reachable hosts
	r.log.Debugf("Host %s is not found on the list of reachable hosts", req.Host)
}
// CheckIfHostIsReachable reports whether reqHost is one of the hosts this
// client was configured with as reachable. The comparison is an exact,
// case-sensitive string match.
func (r *HTTPClient) CheckIfHostIsReachable(reqHost string) bool {
	found := false
	for i := 0; i < len(r.hosts) && !found; i++ {
		found = r.hosts[i] == reqHost
	}
	return found
}
// proxyDialer returns a dialer that routes connections through proxyURL.
//
// http and https proxies are handled by go-http-dialer; socks5 and socks5h
// proxies are handled by golang.org/x/net/proxy. Both paths use a connect
// timeout of r.timeoutSecs for reaching the proxy. Any other scheme is
// rejected with an error.
func (r *HTTPClient) proxyDialer(proxyURL *url.URL) (proxy.Dialer, error) {
	timeout := time.Duration(r.timeoutSecs) * time.Second
	switch proxyURL.Scheme {
	case "http", "https":
		return http_dialer.New(proxyURL, http_dialer.WithConnectionTimeout(timeout)), nil
	case "socks5", "socks5h":
		// Dial the SOCKS server with the same connect timeout as the HTTP
		// proxy path; proxy.Direct would wait without any limit.
		return proxy.FromURL(proxyURL, &net.Dialer{Timeout: timeout})
	default:
		return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
	}
}
// shouldRetry reports whether Do should send req again after an attempt that
// produced resp/err. When the answer is true, it has already waited out the
// back-off delay.
//
// attempts is the zero-based index of the attempt that just finished.
// rateLimitSleep is the base delay, in seconds, used for API rate-limit and
// invalid-download-code back-off.
//
// Rules, in order:
//   - addMagnet calls on api.real-debrid.com are never retried, to avoid
//     adding the same torrent twice.
//   - API error codes 5, 34 and 429 (rate limiting), 36 (fair usage limit),
//     -1 (internal error) and 503 (service unavailable) are retried with no
//     attempt limit. Any other API error is final.
//   - The download error "invalid_download_code" is retried up to
//     r.maxRetries times. Other download errors are final.
//   - A 200/206 response to a Range request that does not start at byte 0,
//     but that lacks a Content-Range header, is retried up to r.maxRetries
//     times. Any other response is final.
//   - Failures without a response are retried up to r.maxRetries times.
//
// Every wait is abandoned, and false is returned, as soon as req's context is
// done.
func (r *HTTPClient) shouldRetry(req *http.Request, resp *http.Response, err error, attempts, rateLimitSleep int) bool {
	// assume that all addMagnet requests are always successful;
	// don't retry to prevent duplicate torrents
	if req.Host == "api.real-debrid.com" && strings.HasSuffix(req.URL.Path, "torrents/addMagnet") {
		return false
	}
	ctx := req.Context()

	if apiErr, ok := err.(*ApiErrorResponse); ok {
		reason, base := "", rateLimitSleep
		switch apiErr.Code {
		case 5, 34, 429: // slow down / too many requests (retry infinitely)
			reason = "API rate limit reached"
		case 36: // fair usage limit (retry infinitely)
			reason = "Fair usage limit reached"
		case -1: // internal error (retry infinitely, base delay of 1s)
			reason, base = "RD Internal error", 1
		case 503: // service unavailable (retry infinitely)
			reason = "RD Service Unavailable"
		default:
			return false
		}
		delay := r.backoff(attempts, base)
		r.log.Debugf("%s, attempt #%d, retrying in %d seconds", reason, attempts+1, delay/time.Second)
		return sleepUnlessDone(ctx, delay)
	}

	if downloadErr, ok := err.(*DownloadErrorResponse); ok {
		if downloadErr.Message != "invalid_download_code" { // 404
			return false
		}
		if attempts >= r.maxRetries {
			r.log.Debugf("Invalid download code, attempt #%d", attempts+1)
			return false
		}
		delay := r.backoff(attempts, rateLimitSleep)
		r.log.Debugf("Invalid download code, attempt #%d, retrying in %d seconds", attempts+1, delay/time.Second)
		return sleepUnlessDone(ctx, delay)
	}

	// responses without an error status
	if resp != nil {
		okResponseCode := resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusPartialContent
		// the request asks for an offset range but the server didn't answer with a Content-Range header
		rangeHeader := req.Header.Get("Range")
		wantsOffset := rangeHeader != "" && !strings.HasPrefix(rangeHeader, "bytes=0-")
		if !okResponseCode || !wantsOffset || resp.Header.Get("Content-Range") != "" {
			return false
		}
		// Bounded and backed off: a server that never honours Range must not
		// cause an endless, sleep-free retry loop.
		if attempts >= r.maxRetries {
			r.log.Debugf("Range request not honoured, attempt #%d", attempts+1)
			return false
		}
		delay := r.backoff(attempts, 1)
		r.log.Debugf("Range request not honoured, attempt #%d, retrying in %d seconds", attempts+1, delay/time.Second)
		return sleepUnlessDone(ctx, delay)
	}

	if attempts >= r.maxRetries {
		r.log.Debugf("Request failed, attempt #%d (error=%v)", attempts+1, err)
		return false
	}
	delay := r.backoff(attempts, 1)
	r.log.Debugf("Request failed, attempt #%d, retrying in %d seconds (error=%v)", attempts+1, delay/time.Second, err)
	return sleepUnlessDone(ctx, delay)
}

// sleepUnlessDone waits for d and reports true. It returns false as soon as
// ctx is done, or immediately when ctx is already done.
func sleepUnlessDone(ctx context.Context, d time.Duration) bool {
	if d <= 0 {
		return ctx.Err() == nil
	}
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}
// backoffFunc returns an exponential back-off delay of base * 2^attempt
// seconds, capped at 60 seconds.
//
// The cap is checked in float64 before any int multiplication. Computing
// base * int(2^attempt) directly overflows int for large attempt counts, and
// "retry forever" callers do reach those counts. The wrapped result is a zero
// or negative delay, which would turn back-off into a busy loop.
func backoffFunc(attempt, base int) time.Duration {
	const maxSecs = 60
	factor := math.Pow(2, float64(attempt))
	if float64(base)*factor > maxSecs {
		return maxSecs * time.Second
	}
	// here base*factor <= maxSecs, so the int conversion cannot overflow
	return time.Duration(base*int(factor)) * time.Second
}
// VerifyLink sends a HEAD request for link through Do, so authentication,
// retries and host rewriting all apply. It returns nil only when the final
// response is 200 OK. A single deadline of r.timeoutSecs is attached to the
// request, so it spans all attempts made by Do rather than each one.
func (r *HTTPClient) VerifyLink(link string) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(r.timeoutSecs)*time.Second)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, http.MethodHead, link, nil)
	if err != nil {
		return err
	}
	resp, err := r.Do(req)
	// Do returns a non-nil response together with an error for 4xx/5xx
	// statuses; close the body on every path.
	if resp != nil {
		defer resp.Body.Close()
	}
	if err != nil {
		return err
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}
	return nil
}