diff --git a/server/ctrl/webdav.go b/server/ctrl/webdav.go index c554fe007..9856c5001 100644 --- a/server/ctrl/webdav.go +++ b/server/ctrl/webdav.go @@ -1,12 +1,14 @@ package ctrl import ( - . "github.com/mickael-kerjean/filestash/server/common" - "github.com/mickael-kerjean/filestash/server/model" - "github.com/mickael-kerjean/net/webdav" "net/http" "path/filepath" "strings" + + . "github.com/mickael-kerjean/filestash/server/common" + "github.com/mickael-kerjean/filestash/server/middleware" + "github.com/mickael-kerjean/filestash/server/model" + "github.com/mickael-kerjean/net/webdav" ) func WebdavHandler(ctx *App, res http.ResponseWriter, req *http.Request) { @@ -53,8 +55,8 @@ func WebdavHandler(ctx *App, res http.ResponseWriter, req *http.Request) { * an imbecile and considering we can't even see the source code they are running, the best approach we * could go on is: "crap in, crap out" where useless request coming in are identified and answer appropriatly */ -func WebdavBlacklist(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func WebdavBlacklist(fn middleware.HandlerFunc) middleware.HandlerFunc { + return middleware.HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { base := filepath.Base(req.URL.String()) if req.Method == "PUT" || req.Method == "MKCOL" { @@ -125,5 +127,5 @@ func WebdavBlacklist(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx } } fn(ctx, res, req) - } + }) } diff --git a/server/middleware/context.go b/server/middleware/context.go index f1f350455..cb76b5cd5 100644 --- a/server/middleware/context.go +++ b/server/middleware/context.go @@ -9,7 +9,7 @@ import ( "strings" ) -func BodyParser(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { +func BodyParser(fn HandlerFunc) HandlerFunc { extractBody := func(req *http.Request) (map[string]interface{}, error) { body := map[string]interface{}{} byt, err := ioutil.ReadAll(req.Body) @@ -25,14 +25,14 @@ func BodyParser(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App return body, nil } - return func(ctx *App, res http.ResponseWriter, req *http.Request) { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { var err error if ctx.Body, err = extractBody(req); err != nil { SendErrorResult(res, ErrNotValid) return } fn(ctx, res, req) - } + }) } func GenerateRequestID(prefix string) string { diff --git a/server/middleware/http.go b/server/middleware/http.go index cedac00dc..39bcdcc08 100644 --- a/server/middleware/http.go +++ b/server/middleware/http.go @@ -10,8 +10,8 @@ import ( "strings" ) -func ApiHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func ApiHeaders(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { header := res.Header() header.Set("Content-Type", "application/json") header.Set("Cache-Control", "no-cache") @@ -20,20 +20,20 @@ func ApiHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App header.Set("X-Request-ID", GenerateRequestID("API")) } fn(ctx, res, req) - } + }) } -func StaticHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func StaticHeaders(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { header := res.Header() header.Set("Content-Type", GetMimeType(filepath.Ext(req.URL.Path))) header.Set("Cache-Control", "max-age=2592000") fn(ctx, res, req) - } + }) } -func IndexHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func IndexHeaders(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { header := res.Header() header.Set("Content-Type", "text/html") header.Set("Cache-Control", "no-cache") @@ -65,11 +65,11 @@ func IndexHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A } // header.Set("Content-Security-Policy", cspHeader) fn(ctx, res, req) - } + }) } -func SecureHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func SecureHeaders(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { header := res.Header() if Config.Get("general.force_ssl").Bool() { header.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") @@ -77,11 +77,11 @@ func SecureHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx * header.Set("X-Content-Type-Options", "nosniff") header.Set("X-XSS-Protection", "1; mode=block") fn(ctx, res, req) - } + }) } -func SecureOrigin(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func SecureOrigin(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { if host := Config.Get("general.host").String(); host != "" { host = strings.TrimPrefix(host, "http://") host = strings.TrimPrefix(host, "https://") @@ -105,11 +105,11 @@ func SecureOrigin(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A Log.Warning("Intrusion detection: %s - %s", RetrievePublicIp(req), req.URL.String()) SendErrorResult(res, ErrNotAllowed) - } + }) } -func WithPublicAPI(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func WithPublicAPI(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { apiKey := req.URL.Query().Get("key") if apiKey == "" { fn(ctx, res, req) @@ -132,13 +132,13 @@ func WithPublicAPI(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx * return } fn(ctx, res, req) - } + }) } var limiter = rate.NewLimiter(10, 1000) -func RateLimiter(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func RateLimiter(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { if limiter.Allow() == false { Log.Warning("middleware::http::ratelimit too many requests") SendErrorResult( @@ -148,7 +148,7 @@ func RateLimiter(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *Ap return } fn(ctx, res, req) - } + }) } func EnableCors(req *http.Request, res http.ResponseWriter, host string) error { diff --git a/server/middleware/index.go b/server/middleware/index.go index fcfccc5c0..db7c1af79 100644 --- a/server/middleware/index.go +++ b/server/middleware/index.go @@ -1,15 +1,13 @@ package middleware import ( - "bytes" - "encoding/json" . "github.com/mickael-kerjean/filestash/server/common" "net/http" - "sync" "time" ) -var telemetry = Telemetry{Data: make([]LogEntry, 0)} +type HandlerFunc func(*App, http.ResponseWriter, *http.Request) +type Middleware func(HandlerFunc) HandlerFunc func init() { Hooks.Register.Onload(func() { @@ -22,10 +20,7 @@ func init() { }) } -type Middleware func(func(*App, http.ResponseWriter, *http.Request)) func(*App, http.ResponseWriter, *http.Request) - -func NewMiddlewareChain(fn func(*App, http.ResponseWriter, *http.Request), m []Middleware, app App) http.HandlerFunc { - +func NewMiddlewareChain(fn HandlerFunc, m []Middleware, app App) http.HandlerFunc { return func(res http.ResponseWriter, req *http.Request) { var resw ResponseWriter = NewResponseWriter(res) var f func(*App, http.ResponseWriter, *http.Request) = fn @@ -37,7 +32,7 @@ func NewMiddlewareChain(fn func(*App, http.ResponseWriter, *http.Request), m []M if req.Body != nil { req.Body.Close() } - go Logger(app, &resw, req) + go logger(app, &resw, req) } } @@ -65,112 +60,3 @@ func (w *ResponseWriter) Write(b []byte) (int, error) { } return w.ResponseWriter.Write(b) } - -type LogEntry struct { - Host string `json:"host"` - Method string `json:"method"` - RequestURI string `json:"pathname"` - Proto string `json:"proto"` - Status int `json:"status"` - Scheme string `json:"scheme"` - UserAgent string `json:"userAgent"` - Ip string `json:"ip"` - Referer string `json:"referer"` - Duration float64 `json:"responseTime"` - Version string `json:"version"` - Backend string `json:"backend"` - Share string `json:"share"` - License string `json:"license"` - Session string `json:"session"` - RequestID string `json:"requestID"` -} - -func Logger(ctx App, res http.ResponseWriter, req *http.Request) { - if obj, ok := res.(*ResponseWriter); ok && req.RequestURI != "/about" { - point := LogEntry{ - Version: APP_VERSION + "." + BUILD_DATE, - License: LICENSE, - Scheme: req.URL.Scheme, - Host: req.Host, - Method: req.Method, - RequestURI: req.RequestURI, - Proto: req.Proto, - Status: obj.status, - UserAgent: req.Header.Get("User-Agent"), - Ip: req.RemoteAddr, - Referer: req.Referer(), - Duration: float64(time.Now().Sub(obj.start)) / (1000 * 1000), - Backend: func() string { - if ctx.Session["type"] == "" { - return "null" - } - return ctx.Session["type"] - }(), - Share: func() string { - if ctx.Share.Id == "" { - return "null" - } - return ctx.Share.Id - }(), - Session: func() string { - if ctx.Session["type"] == "" { - return "null" - } - return GenerateID(&ctx) - }(), - RequestID: func() string { - defer func() string { - if r := recover(); r != nil { - return "oops" - } - return "null" - }() - return res.Header().Get("X-Request-ID") - }(), - } - if Config.Get("log.telemetry").Bool() { - telemetry.Record(point) - } - if Config.Get("log.enable").Bool() { - Log.Stdout("HTTP %3d %3s %6.1fms %s", point.Status, point.Method, point.Duration, point.RequestURI) - } - } -} - -type Telemetry struct { - Data []LogEntry - mu sync.Mutex -} - -func (this *Telemetry) Record(point LogEntry) { - this.mu.Lock() - this.Data = append(this.Data, point) - this.mu.Unlock() -} - -func (this *Telemetry) Flush() { - if len(this.Data) == 0 { - return - } - this.mu.Lock() - pts := this.Data - this.Data = make([]LogEntry, 0) - this.mu.Unlock() - - body, err := json.Marshal(pts) - if err != nil { - return - } - r, err := http.NewRequest("POST", "https://downloads.filestash.app/event", bytes.NewReader(body)) - r.Header.Set("Connection", "Close") - r.Header.Set("Content-Type", "application/json") - r.Close = true - if err != nil { - return - } - resp, err := HTTP.Do(r) - if err != nil { - return - } - resp.Body.Close() -} diff --git a/server/middleware/session.go b/server/middleware/session.go index 67e5a48fb..c0f72f34a 100644 --- a/server/middleware/session.go +++ b/server/middleware/session.go @@ -14,18 +14,18 @@ import ( "time" ) -func LoggedInOnly(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func LoggedInOnly(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { if ctx.Backend == nil || ctx.Session == nil { SendErrorResult(res, ErrPermissionDenied) return } fn(ctx, res, req) - } + }) } -func AdminOnly(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func AdminOnly(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { if admin := Config.Get("auth.admin").String(); admin != "" { c, err := req.Cookie(COOKIE_NAME_ADMIN) if err != nil { @@ -47,11 +47,11 @@ func AdminOnly(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, } } fn(ctx, res, req) - } + }) } -func SessionStart(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func SessionStart(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { var err error if ctx.Share, err = _extractShare(req); err != nil { @@ -72,21 +72,21 @@ func SessionStart(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A return } fn(ctx, res, req) - } + }) } -func SessionTry(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func SessionTry(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { ctx.Share, _ = _extractShare(req) ctx.Authorization = _extractAuthorization(req) ctx.Session, _ = _extractSession(req, ctx) ctx.Backend, _ = _extractBackend(req, ctx) fn(ctx, res, req) - } + }) } -func RedirectSharedLoginIfNeeded(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func RedirectSharedLoginIfNeeded(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { share_id := _extractShareId(req) if share_id == "" { if mux.Vars(req)["share"] == "private" { @@ -103,11 +103,11 @@ func RedirectSharedLoginIfNeeded(fn func(*App, http.ResponseWriter, *http.Reques return } fn(ctx, res, req) - } + }) } -func CanManageShare(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) { - return func(ctx *App, res http.ResponseWriter, req *http.Request) { +func CanManageShare(fn HandlerFunc) HandlerFunc { + return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) { share_id := mux.Vars(req)["share"] if share_id == "" { Log.Debug("middleware::session::share 'invalid share id'") @@ -167,7 +167,7 @@ func CanManageShare(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx } SendErrorResult(res, ErrPermissionDenied) return - } + }) } func _extractAuthorization(req *http.Request) (token string) { diff --git a/server/middleware/telemetry.go b/server/middleware/telemetry.go new file mode 100644 index 000000000..b7e495009 --- /dev/null +++ b/server/middleware/telemetry.go @@ -0,0 +1,122 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "net/http" + "sync" + "time" + + . "github.com/mickael-kerjean/filestash/server/common" +) + +var telemetry = Telemetry{Data: make([]LogEntry, 0)} + +type Telemetry struct { + Data []LogEntry + mu sync.Mutex +} + +type LogEntry struct { + Host string `json:"host"` + Method string `json:"method"` + RequestURI string `json:"pathname"` + Proto string `json:"proto"` + Status int `json:"status"` + Scheme string `json:"scheme"` + UserAgent string `json:"userAgent"` + Ip string `json:"ip"` + Referer string `json:"referer"` + Duration float64 `json:"responseTime"` + Version string `json:"version"` + Backend string `json:"backend"` + Share string `json:"share"` + License string `json:"license"` + Session string `json:"session"` + RequestID string `json:"requestID"` +} + +func logger(ctx App, res http.ResponseWriter, req *http.Request) { + if obj, ok := res.(*ResponseWriter); ok && req.RequestURI != "/about" { + point := LogEntry{ + Version: APP_VERSION + "." + BUILD_DATE, + License: LICENSE, + Scheme: req.URL.Scheme, + Host: req.Host, + Method: req.Method, + RequestURI: req.RequestURI, + Proto: req.Proto, + Status: obj.status, + UserAgent: req.Header.Get("User-Agent"), + Ip: req.RemoteAddr, + Referer: req.Referer(), + Duration: float64(time.Now().Sub(obj.start)) / (1000 * 1000), + Backend: func() string { + if ctx.Session["type"] == "" { + return "null" + } + return ctx.Session["type"] + }(), + Share: func() string { + if ctx.Share.Id == "" { + return "null" + } + return ctx.Share.Id + }(), + Session: func() string { + if ctx.Session["type"] == "" { + return "null" + } + return GenerateID(&ctx) + }(), + RequestID: func() string { + defer func() string { + if r := recover(); r != nil { + return "oops" + } + return "null" + }() + return res.Header().Get("X-Request-ID") + }(), + } + if Config.Get("log.telemetry").Bool() { + telemetry.Record(point) + } + if Config.Get("log.enable").Bool() { + Log.Stdout("HTTP %3d %3s %6.1fms %s", point.Status, point.Method, point.Duration, point.RequestURI) + } + } +} + +func (this *Telemetry) Record(point LogEntry) { + this.mu.Lock() + this.Data = append(this.Data, point) + this.mu.Unlock() +} + +func (this *Telemetry) Flush() { + if len(this.Data) == 0 { + return + } + this.mu.Lock() + pts := this.Data + this.Data = make([]LogEntry, 0) + this.mu.Unlock() + + body, err := json.Marshal(pts) + if err != nil { + return + } + r, err := http.NewRequest("POST", "https://downloads.filestash.app/event", bytes.NewReader(body)) + r.Header.Set("Connection", "Close") + r.Header.Set("Content-Type", "application/json") + r.Close = true + if err != nil { + return + } + resp, err := HTTP.Do(r) + if err != nil { + return + } + resp.Body.Close() +}