Files
zurg/pkg/http/client.go
2024-01-08 21:13:35 +01:00

186 lines
4.4 KiB
Go

package http
import (
"context"
"encoding/json"
"fmt"
"io"
"math"
"net"
"net/http"
"strings"
"time"
"github.com/debridmediamanager/zurg/internal/config"
"github.com/debridmediamanager/zurg/pkg/logutil"
cmap "github.com/orcaman/concurrent-map/v2"
)
// RATE_LIMIT_FACTOR multiplies the caller-supplied retry budget in
// NewHTTPClient. It is also the increment that selects exponential backoff
// in Do: Do compares the increment with `incr >= RATE_LIMIT_FACTOR`, while
// ordinary retries add 1. It must therefore stay greater than 1, or every
// retry would be treated as a backoff retry.
const RATE_LIMIT_FACTOR = 4
// HTTPClient wraps a standard *http.Client with optional bearer-token
// authentication, a pluggable retry policy used by Do, and (when the config
// asks for it) IPv6-preferring dialing set up in NewHTTPClient.
type HTTPClient struct {
	// client performs the actual HTTP round trips.
	client *http.Client

	// maxRetries is the ceiling for the cumulative attempt counter in Do.
	maxRetries int

	// backoff maps the attempt counter to a sleep duration; Do uses it for
	// retry increments of at least RATE_LIMIT_FACTOR.
	backoff func(attempt int) time.Duration

	// getRetryIncr reports how much an outcome adds to the attempt counter;
	// a result of 0 (or less) stops the retry loop.
	getRetryIncr func(resp *http.Response, hasRangeHeader bool, err error) int

	// bearerToken, when non-empty, is sent as "Authorization: Bearer <token>".
	bearerToken string

	// cfg supplies the IPv6 switch and the fixed rate-limit sleep.
	cfg config.ConfigInterface

	// ipv6 caches the resolved IPv6 dial address per original host:port.
	ipv6 cmap.ConcurrentMap[string, string]

	// log receives request error messages.
	log *logutil.Logger
}
// ErrorResponse is the JSON error envelope that Do tries to decode from the
// body of any response whose status is neither 200 nor 206. A typical
// payload looks like:
//
//	{
//	  "error": "infringing_file",
//	  "error_code": 35
//	}
type ErrorResponse struct {
	Message string `json:"error"`      // error identifier, e.g. "infringing_file"
	Code    int    `json:"error_code"` // numeric API error code, e.g. 35
}

// Error implements the error interface. The "api response error" prefix is
// what NewHTTPClient's retry policy looks for before type-asserting.
func (er *ErrorResponse) Error() string {
	msg, code := er.Message, er.Code
	return fmt.Sprintf("api response error: %s (code: %d)", msg, code)
}
func NewHTTPClient(token string, maxRetries int, timeoutSecs int, cfg config.ConfigInterface, log *logutil.Logger) *HTTPClient {
client := HTTPClient{
bearerToken: token,
client: &http.Client{
Timeout: time.Duration(timeoutSecs) * time.Second,
},
maxRetries: maxRetries * RATE_LIMIT_FACTOR,
backoff: func(attempt int) time.Duration {
maxDuration := 60
backoff := int(math.Pow(2, float64(attempt)))
if backoff > maxDuration {
backoff = maxDuration
}
return time.Duration(backoff) * time.Second
},
getRetryIncr: func(resp *http.Response, hasRangeHeader bool, err error) int {
if resp != nil {
if resp.StatusCode == 429 {
return 1
}
if resp.StatusCode == http.StatusOK && hasRangeHeader {
return 1
}
return 0 // don't retry
} else if err != nil {
log.Errorf("Client request error: %s", err.Error())
if strings.Contains(err.Error(), "api response error") {
if apiErr, ok := err.(*ErrorResponse); ok {
switch apiErr.Code {
case -1: // Internal error
return 1
case 5: // Slow down
return 1
case 6: // Ressource unreachable
return 1
case 17: // Hoster in maintenance
return 1
case 19: // Hoster temporarily unavailable
return 1
case 25: // Service unavailable
return 1
case 34: // Too many requests
return 1
case 36: // Fair Usage Limit
return 1
default:
return 0 // don't retry
}
}
}
return 1
}
return RATE_LIMIT_FACTOR
},
cfg: cfg,
ipv6: cmap.New[string](),
log: log,
}
if cfg.ShouldForceIPv6() {
dialer := &net.Dialer{}
dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
if ipv6Address, ok := client.ipv6.Get(address); ok {
return dialer.DialContext(ctx, network, ipv6Address)
}
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
for _, ip := range ips {
if ip.IP.To4() == nil { // IPv6 address found
ip6Host := ip.IP.String()
ipv6Address := net.JoinHostPort(ip6Host, port)
client.ipv6.Set(address, ipv6Address)
return dialer.DialContext(ctx, network, ipv6Address)
}
}
return dialer.DialContext(ctx, network, address)
}
transport := &http.Transport{
DialContext: dialContext,
}
client.client.Transport = transport
}
return &client
}
// Do sends req, adding the client's bearer token if one is configured, and
// retries it. It stops when getRetryIncr returns 0 or when the cumulative
// attempt counter exceeds maxRetries.
//
// For responses other than 200/206, the body is read and, if it parses as
// JSON, the returned error is an *ErrorResponse. The bytes are then put back
// on resp.Body so the caller can still read them. As a result, resp and err
// may both be non-nil. Whenever resp is non-nil, the caller must close
// resp.Body.
//
// Between retries Do does three things:
//   - closes the previous response body;
//   - waits, using the backoff for increments of at least RATE_LIMIT_FACTOR
//     and the configured rate-limit sleep otherwise;
//   - rewinds the request body via req.GetBody when available.
//
// If req's context is cancelled during the wait, Do returns (nil, ctx.Err()).
func (r *HTTPClient) Do(req *http.Request) (*http.Response, error) {
	if r.bearerToken != "" {
		req.Header.Set("Authorization", "Bearer "+r.bearerToken)
	}
	// check if Range header is set
	hasRangeHeader := req.Header.Get("Range") != ""

	var resp *http.Response
	var err error
	attempt := 0
	for {
		resp, err = r.client.Do(req)
		if resp != nil && resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
			body, readErr := io.ReadAll(resp.Body)
			resp.Body.Close()
			// Reading consumed the body; restore it so callers can still
			// inspect the error payload.
			resp.Body = io.NopCloser(strings.NewReader(string(body)))
			if readErr == nil && len(body) > 0 {
				var errResp ErrorResponse
				if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil {
					err = &errResp
				}
			}
		}

		incr := r.getRetryIncr(resp, hasRangeHeader, err)
		if incr <= 0 {
			break // don't retry anymore
		}
		attempt += incr
		if attempt > r.maxRetries {
			break
		}

		var delay time.Duration
		if incr >= RATE_LIMIT_FACTOR {
			delay = r.backoff(attempt)
		} else {
			delay = time.Duration(r.cfg.GetRateLimitSleepSeconds()) * time.Second // extra delay
		}
		if resp != nil {
			resp.Body.Close()
		}

		// Wait before retrying, but give up promptly if the caller cancels.
		timer := time.NewTimer(delay)
		select {
		case <-req.Context().Done():
			timer.Stop()
			return nil, req.Context().Err()
		case <-timer.C:
		}

		// The previous attempt consumed req.Body; rewind it so the retry
		// sends the full payload again.
		if req.GetBody != nil {
			newBody, bodyErr := req.GetBody()
			if bodyErr != nil {
				return nil, bodyErr
			}
			req.Body = newBody
		}
	}
	return resp, err
}