commit da2c53bf86a8ca7f6183f286f4e8765c219eafc8 Author: Ben Sarmiento Date: Mon Oct 16 21:31:51 2023 +0200 Initial commit :rainbow: diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3b735ec --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work diff --git a/README.md b/README.md new file mode 100644 index 0000000..b30e393 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# zurg diff --git a/cmd/zurg/main.go b/cmd/zurg/main.go new file mode 100644 index 0000000..c5d6c51 --- /dev/null +++ b/cmd/zurg/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "log" + "net/http" + "os" + + "github.com/debridmediamanager.com/zurg/internal/dav" + "github.com/debridmediamanager.com/zurg/pkg/repo" +) + +func main() { + mux := http.NewServeMux() + db, dbErr := repo.NewDatabase(os.Getenv("DB_DSN")) + if dbErr != nil { + log.Println(dbErr) + } + + dav.Router(mux, db) + + log.Println("Listening on port 8123...") + err := http.ListenAndServe(":8123", mux) + if err != nil { + log.Println(err) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7148f83 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module github.com/debridmediamanager.com/zurg + +go 1.21.3 + +require ( + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/go-sql-driver/mysql v1.7.1 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/golang/protobuf v1.5.3 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/qianbin/directcache v0.9.7 // indirect + github.com/zeebo/xxh3 v1.0.2 // indirect + google.golang.org/protobuf v1.26.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a16b321 --- /dev/null +++ b/go.sum @@ -0,0 +1,26 @@ +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/qianbin/directcache v0.9.7 h1:DH6MdmU0fVjcKry57ju7U6akTFDBnLhHd0xOHZDq948= +github.com/qianbin/directcache v0.9.7/go.mod h1:gZBpa9NqO1Qz7wZKO7t7atBA76bT8X0eM01PdveW4qc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/dav/response.go b/internal/dav/response.go new file mode 100644 index 0000000..c28ec5e --- /dev/null +++ b/internal/dav/response.go @@ -0,0 +1,99 @@ +package dav + +import ( + "os" + "path/filepath" + + "github.com/debridmediamanager.com/zurg/pkg/dav" + "github.com/debridmediamanager.com/zurg/pkg/realdebrid" + "github.com/debridmediamanager.com/zurg/pkg/repo" +) + +func createMultiTorrentResponse(torrents []realdebrid.Torrent) (*dav.MultiStatus, error) { + var responses []dav.Response + + // initial response is the directory itself + responses = append(responses, dav.Directory("/torrents")) + + // add all files and directories in the directory + for _, item := range torrents { + if item.Progress != 100 { + continue + } + + path := filepath.Join("/torrents", item.Filename) + responses = append(responses, dav.Directory(path)) + } + + return &dav.MultiStatus{ + XMLNS: "DAV:", + Response: responses, + }, nil +} + +func createSingleTorrentResponse(torrent realdebrid.Torrent, db *repo.Database) (*dav.MultiStatus, error) { + var responses []dav.Response + + // initial response is the directory itself + currentPath := filepath.Join("/torrents", torrent.Filename) + responses = append(responses, dav.Directory(currentPath)) + + davFiles, err := db.GetMultiple(torrent.Hash) + if err != nil { + return nil, err + } + + // Create a map for O(1) lookups of the cached links + cachedLinksMap := make(map[string]*repo.DavFile) + for _, u := range davFiles.Files { + cachedLinksMap[u.Link] = u + } + for _, link := range torrent.Links { + if u, exists := cachedLinksMap[link]; exists { + if u.Filesize == 0 { + // This link is cached but the filesize is 0 + // This means that the link is dead + continue + } + path := filepath.Join(currentPath, u.Filename) + response := dav.File( + path, + int(u.Filesize), + torrent.Added, // Assuming you want to use the torrent added time here + ) + responses = append(responses, response) + } else { + // This link is not cached yet + unrestrictFn := func() (realdebrid.UnrestrictResponse, error) { + return realdebrid.UnrestrictCheck(os.Getenv("RD_TOKEN"), link) + } + unrestrictResponse := realdebrid.RetryUntilOk(unrestrictFn) + if unrestrictResponse == nil { + db.Insert(torrent.Hash, torrent.Filename, realdebrid.UnrestrictResponse{ + Filename: "", + Filesize: 0, + Link: link, + Host: "", + }) + continue + } else { + db.Insert(torrent.Hash, torrent.Filename, *unrestrictResponse) + } + + path := filepath.Join(currentPath, unrestrictResponse.Filename) + response := dav.File( + path, + int(unrestrictResponse.Filesize), + torrent.Added, + ) + responses = append(responses, response) + } + } + + // TODO: dedupe the links in the response + + return &dav.MultiStatus{ + XMLNS: "DAV:", + Response: responses, + }, nil +} diff --git a/internal/dav/router.go b/internal/dav/router.go new file mode 100644 index 0000000..674ce2b --- /dev/null +++ b/internal/dav/router.go @@ -0,0 +1,127 @@ +package dav + +import ( + "encoding/xml" + "fmt" + "log" + "net/http" + "os" + "path" + "strings" + + "github.com/debridmediamanager.com/zurg/pkg/dav" + "github.com/debridmediamanager.com/zurg/pkg/realdebrid" + "github.com/debridmediamanager.com/zurg/pkg/repo" +) + +func findTorrentByFilename(torrents []realdebrid.Torrent, filename string) *realdebrid.Torrent { + for _, torrent := range torrents { + if torrent.Filename == filename { + return &torrent + } + } + return nil +} + +func Router(mux *http.ServeMux, db *repo.Database) { + torrents, err := realdebrid.GetTorrents(os.Getenv("RD_TOKEN")) + if err != nil { + log.Printf("Cannot get torrents: %v", err.Error()) + return + } + + rootResponse := dav.MultiStatus{ + XMLNS: "DAV:", + Response: []dav.Response{ + dav.Directory("/"), + dav.Directory("/torrents"), + }, + } + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + requestPath := path.Clean(r.URL.Path) + + switch r.Method { + case "PROPFIND": + log.Println("PROPFIND", requestPath) + var output []byte + var err error + + if requestPath == "/" { + output, err = xml.MarshalIndent(rootResponse, "", " ") + } else if requestPath == "/torrents" { + var allTorrentsResponse *dav.MultiStatus + allTorrentsResponse, err = createMultiTorrentResponse(torrents) + if err != nil { + log.Printf("Cannot read directory: %v", err.Error()) + http.Error(w, fmt.Sprintf("Cannot read directory: %v", err.Error()), http.StatusInternalServerError) + return + } + output, err = xml.MarshalIndent(allTorrentsResponse, "", " ") + } else { + lastSegment := path.Base(requestPath) + torrent := findTorrentByFilename(torrents, lastSegment) + if torrent == nil { + log.Println("Cannot find directory") + http.Error(w, "Cannot find directory", http.StatusNotFound) + return + } + + var torrentResponse *dav.MultiStatus + torrentResponse, err = createSingleTorrentResponse(*torrent, db) + if err != nil { + log.Printf("Cannot read directory: %v", err.Error()) + http.Error(w, fmt.Sprintf("Cannot read directory: %v", err.Error()), http.StatusInternalServerError) + return + } + output, err = xml.MarshalIndent(torrentResponse, "", " ") + } + + if err != nil { + log.Printf("Cannot marshal xml: %v", err.Error()) + http.Error(w, fmt.Sprintf("Cannot marshal xml: %v", err.Error()), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/xml; charset=\"utf-8\"") + w.WriteHeader(http.StatusMultiStatus) + fmt.Fprintf(w, "\n%s\n", output) + + case http.MethodOptions: + log.Println("OPTIONS", requestPath) + w.WriteHeader(http.StatusOK) + + case http.MethodGet: + log.Println("GET", requestPath) + segments := strings.Split(requestPath, "/") + + // If there are less than 2 segments, return an error or adjust as needed + if len(segments) < 2 { + log.Println("Cannot find file") + http.Error(w, "Cannot find file", http.StatusNotFound) + } + + // Get the last two segments + secondLast := segments[len(segments)-2] + last := segments[len(segments)-1] + unrestrict, dbErr := db.Get(secondLast, last) + if dbErr != nil { + log.Printf("Cannot find file in db: %v", dbErr.Error()) + http.Error(w, fmt.Sprintf("Cannot find file in db: %v", dbErr.Error()), http.StatusInternalServerError) + return + } + + resp, err := realdebrid.UnrestrictLink(os.Getenv("RD_TOKEN"), unrestrict.Link) + if err != nil { + log.Printf("Cannot unrestrict link: %v", err.Error()) + http.Error(w, fmt.Sprintf("Cannot unrestrict link: %v", err.Error()), http.StatusInternalServerError) + return + } + http.Redirect(w, r, resp.Download, http.StatusFound) + + default: + log.Println("Method not implemented") + http.Error(w, "Method not implemented", http.StatusMethodNotAllowed) + } + }) +} diff --git a/pkg/dav/response.go b/pkg/dav/response.go new file mode 100644 index 0000000..ac53ec7 --- /dev/null +++ b/pkg/dav/response.go @@ -0,0 +1,28 @@ +package dav + +func Directory(path string) Response { + return Response{ + Href: customPathEscape(path), + Propstat: PropStat{ + Prop: Prop{ + ResourceType: ResourceType{Collection: &struct{}{}}, + }, + Status: "HTTP/1.1 200 OK", + }, + } +} + +func File(path string, fileSize int, added string) Response { + return Response{ + Href: customPathEscape(path), + Propstat: PropStat{ + Prop: Prop{ + ContentLength: fileSize, + IsHidden: 0, + CreationDate: added, + LastModified: added, + }, + Status: "HTTP/1.1 200 OK", + }, + } +} diff --git a/pkg/dav/types.go b/pkg/dav/types.go new file mode 100644 index 0000000..1458281 --- /dev/null +++ b/pkg/dav/types.go @@ -0,0 +1,31 @@ +package dav + +import "encoding/xml" + +type MultiStatus struct { + XMLName xml.Name `xml:"d:multistatus"` + XMLNS string `xml:"xmlns:d,attr"` + Response []Response `xml:"d:response"` +} + +type Response struct { + Href string `xml:"d:href"` + Propstat PropStat `xml:"d:propstat"` +} + +type PropStat struct { + Prop Prop `xml:"d:prop"` + Status string `xml:"d:status"` +} + +type Prop struct { + ResourceType ResourceType `xml:"d:resourcetype"` + ContentLength int `xml:"d:getcontentlength"` + CreationDate string `xml:"d:creationdate"` + LastModified string `xml:"d:getlastmodified"` + IsHidden int `xml:"d:ishidden"` +} + +type ResourceType struct { + Collection *struct{} `xml:"d:collection,omitempty"` +} diff --git a/pkg/dav/util.go b/pkg/dav/util.go new file mode 100644 index 0000000..4d4cdb3 --- /dev/null +++ b/pkg/dav/util.go @@ -0,0 +1,14 @@ +package dav + +import ( + "net/url" + "strings" +) + +func customPathEscape(input string) string { + segments := strings.Split(input, "/") + for i, segment := range segments { + segments[i] = url.PathEscape(segment) + } + return strings.Join(segments, "/") +} diff --git a/pkg/realdebrid/api.go b/pkg/realdebrid/api.go new file mode 100644 index 0000000..d8e5d46 --- /dev/null +++ b/pkg/realdebrid/api.go @@ -0,0 +1,172 @@ +package realdebrid + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" +) + +func UnrestrictCheck(accessToken, link string) (UnrestrictResponse, error) { + data := url.Values{} + data.Set("link", link) + + req, err := http.NewRequest("POST", "https://api.real-debrid.com/rest/1.0/unrestrict/check", bytes.NewBufferString(data.Encode())) + if err != nil { + return UnrestrictResponse{}, err + } + + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return UnrestrictResponse{}, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return UnrestrictResponse{}, err + } + + if resp.StatusCode != http.StatusOK { + return UnrestrictResponse{}, fmt.Errorf("HTTP error: %s", resp.Status) + } + + var response UnrestrictResponse + err = json.Unmarshal(body, &response) + if err != nil { + return UnrestrictResponse{}, err + } + + return response, nil +} + +func UnrestrictLink(accessToken, link string) (UnrestrictResponse, error) { + data := url.Values{} + data.Set("link", link) + + req, err := http.NewRequest("POST", "https://api.real-debrid.com/rest/1.0/unrestrict/link", bytes.NewBufferString(data.Encode())) + if err != nil { + return UnrestrictResponse{}, err + } + + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return UnrestrictResponse{}, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return UnrestrictResponse{}, err + } + + if resp.StatusCode != http.StatusOK { + return UnrestrictResponse{}, fmt.Errorf("HTTP error: %s", resp.Status) + } + + var response UnrestrictResponse + err = json.Unmarshal(body, &response) + if err != nil { + return UnrestrictResponse{}, err + } + + return response, nil +} + +func GetTorrents(accessToken string) ([]Torrent, error) { + baseURL := "https://api.real-debrid.com/rest/1.0/torrents" + var allTorrents []Torrent + page := 1 + limit := 10 + + for { + params := url.Values{} + params.Set("page", fmt.Sprintf("%d", page)) + params.Set("limit", fmt.Sprintf("%d", limit)) + + reqURL := baseURL + "?" + params.Encode() + + req, err := http.NewRequest("GET", reqURL, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+accessToken) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP error: %s", resp.Status) + } + + var torrents []Torrent + decoder := json.NewDecoder(resp.Body) + err = decoder.Decode(&torrents) + if err != nil { + return nil, err + } + + allTorrents = append(allTorrents, torrents...) + + totalCountHeader := "10" // resp.Header.Get("x-total-count") + totalCount, err := strconv.Atoi(totalCountHeader) + if err != nil { + break + } + + if len(torrents) < limit || len(allTorrents) >= totalCount { + break + } + + page++ + } + + return deduplicateTorrents(allTorrents), nil +} + +func deduplicateTorrents(torrents []Torrent) []Torrent { + mappedTorrents := make(map[string]Torrent) + + for _, t := range torrents { + if existing, ok := mappedTorrents[t.Filename]; ok { + if existing.Hash == t.Hash { + // If hash is the same, combine the links + existing.Links = append(existing.Links, t.Links...) + mappedTorrents[t.Filename] = existing + } else { + // If hash is different, delete old entry and create two new entries + delete(mappedTorrents, t.Filename) + newKey1 := t.Filename + " - " + t.Hash[:4] + newKey2 := existing.Filename + " - " + existing.Hash[:4] + mappedTorrents[newKey1] = t + mappedTorrents[newKey2] = existing + } + } else { + mappedTorrents[t.Filename] = t + } + } + + // Convert the map back to a slice + deduplicated := make([]Torrent, 0, len(mappedTorrents)) + for _, value := range mappedTorrents { + deduplicated = append(deduplicated, value) + } + + return deduplicated +} diff --git a/pkg/realdebrid/types.go b/pkg/realdebrid/types.go new file mode 100644 index 0000000..2cf602b --- /dev/null +++ b/pkg/realdebrid/types.go @@ -0,0 +1,23 @@ +package realdebrid + +type FileJSON struct { + FileSize int `json:"filesize"` + Link string `json:"link"` +} + +type UnrestrictResponse struct { + Filename string `json:"filename"` + Filesize int64 `json:"filesize"` + Link string `json:"link"` + Host string `json:"host"` + Download string `json:"download,omitempty"` + Streamable int `json:"streamable,omitempty"` +} + +type Torrent struct { + Filename string `json:"filename"` + Hash string `json:"hash"` + Progress int `json:"progress"` + Added string `json:"added"` + Links []string `json:"links"` +} diff --git a/pkg/realdebrid/util.go b/pkg/realdebrid/util.go new file mode 100644 index 0000000..74ec045 --- /dev/null +++ b/pkg/realdebrid/util.go @@ -0,0 +1,23 @@ +package realdebrid + +import ( + "math" + "strings" + "time" +) + +func RetryUntilOk[T any](fn func() (T, error)) *T { + const initialDelay = 2 * time.Second + const maxDelay = 128 * time.Second + for i := 0; ; i++ { + result, err := fn() + if err == nil { + return &result + } + if strings.Contains(err.Error(), "404") { + return nil + } + delay := time.Duration(math.Min(float64(initialDelay*time.Duration(math.Pow(2, float64(i)))), float64(maxDelay))) + time.Sleep(delay) + } +} diff --git a/pkg/repo/mysql.go b/pkg/repo/mysql.go new file mode 100644 index 0000000..c7b36b7 --- /dev/null +++ b/pkg/repo/mysql.go @@ -0,0 +1,186 @@ +package repo + +import ( + "bytes" + "database/sql" + "encoding/gob" + "fmt" + "log" + "path" + + "github.com/debridmediamanager.com/zurg/pkg/realdebrid" + _ "github.com/go-sql-driver/mysql" + "github.com/qianbin/directcache" + "github.com/zeebo/xxh3" +) + +type Database struct { + Connection *sql.DB + Cache *directcache.Cache +} + +func GenerateID(directory, filename string) string { + fullPath := path.Join(directory, filename) + hash := xxh3.HashString(fullPath) + return fmt.Sprintf("%016x", hash) +} + +func NewDatabase(dsn string) (*Database, error) { + db, err := sql.Open("mysql", dsn) + if err != nil { + return nil, err + } + + cache := directcache.New(10 << 20) // This initializes a cache with 10 MB + + return &Database{Connection: db, Cache: cache}, nil +} + +func (db *Database) Insert(parentHash, directory string, resp realdebrid.UnrestrictResponse) { + // Generate the ID for the link + var id string + if resp.Filename == "" { + id = GenerateID(directory, resp.Link) + } else { + id = GenerateID(directory, resp.Filename) + } + // Check if the link already exists in the database + var exists int + err := db.Connection.QueryRow("SELECT COUNT(*) FROM Links WHERE ID = ?", id).Scan(&exists) + if err != nil { + log.Printf("failed to check existence: %v", err) + } + + // If link does not exist in the database, insert the new record + if exists == 0 { + _, err = db.Connection.Exec(` + INSERT INTO Links (ID, ParentHash, Directory, Filename, Filesize, Link, Host) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + id, + parentHash, + directory, + resp.Filename, + resp.Filesize, + resp.Link, + resp.Host, + ) + if err != nil { + log.Printf("failed to insert record: %v", err) + } + + // Clear cache for parentHash + db.Cache.Del([]byte(parentHash)) + } +} + +func (db *Database) Get(directory, filename string) (*DavFile, error) { + id := GenerateID(directory, filename) + data, ok := db.Cache.Get([]byte(id)) + if !ok { + resp, err := fetchFromDatabaseByID(db.Connection, id) + if err != nil { + return nil, err + } + + buffer := &bytes.Buffer{} + encoder := gob.NewEncoder(buffer) + if err := encoder.Encode(resp); err != nil { + return nil, err + } + + db.Cache.Set([]byte(id), buffer.Bytes()) + return resp, nil + } + + buffer := bytes.NewBuffer(data) + decoder := gob.NewDecoder(buffer) + var resp DavFile + if err := decoder.Decode(&resp); err != nil { + return nil, err + } + return &resp, nil +} + +func (db *Database) GetMultiple(parentHash string) (*DavFiles, error) { + key := []byte(parentHash) + data, ok := db.Cache.Get(key) + if !ok { + resps, err := fetchMultipleFromDatabase(db.Connection, parentHash) + if err != nil { + return nil, err + } + + buffer := &bytes.Buffer{} + encoder := gob.NewEncoder(buffer) + if err := encoder.Encode(resps); err != nil { + return nil, err + } + + db.Cache.Set(key, buffer.Bytes()) + return resps, nil + } + + buffer := bytes.NewBuffer(data) + decoder := gob.NewDecoder(buffer) + var resps DavFiles + if err := decoder.Decode(&resps); err != nil { + return nil, err + } + + return &resps, nil +} + +func fetchFromDatabaseByID(conn *sql.DB, id string) (*DavFile, error) { + log.Printf("fetching from database: %s", id) + var resp DavFile + + err := conn.QueryRow(` + SELECT Filename, Filesize, Link + FROM Links WHERE ID = ?`, + id, + ).Scan( + &resp.Filename, + &resp.Filesize, + &resp.Link, + ) + if err != nil { + if err == sql.ErrNoRows { + return &resp, nil + } + log.Printf("failed to fetch record: %v", err) + } + + return &resp, nil +} + +func fetchMultipleFromDatabase(conn *sql.DB, parentHash string) (*DavFiles, error) { + log.Printf("fetching multiple from database: %s", parentHash) + rows, err := conn.Query(` + SELECT Filename, Filesize, Link + FROM Links WHERE ParentHash = ?`, + parentHash, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch records: %v", err) + } + defer rows.Close() + + var responses []*DavFile + + for rows.Next() { + resp := &DavFile{} + if err := rows.Scan(&resp.Filename, &resp.Filesize, &resp.Link); err != nil { + return nil, fmt.Errorf("failed to scan row: %v", err) + } + responses = append(responses, resp) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error while iterating over rows: %v", err) + } + + result := &DavFiles{ + Files: responses, + } + + return result, nil +} diff --git a/pkg/repo/types.go b/pkg/repo/types.go new file mode 100644 index 0000000..01d36c4 --- /dev/null +++ b/pkg/repo/types.go @@ -0,0 +1,11 @@ +package repo + +type DavFile struct { + Filename string + Filesize int64 + Link string +} + +type DavFiles struct { + Files []*DavFile +}