From 0838feeb8282398d024432818d8384b009a457b9 Mon Sep 17 00:00:00 2001 From: Ovear Date: Fri, 2 Dec 2022 17:59:59 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9Aintroduce=20buffered=20response=20w?= =?UTF-8?q?riter=20for=20webdav,=20fix=20status/error=20return=20failed.?= =?UTF-8?q?=20(#2544)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: introduce buffered response writer for webdav, fix webdav status/error return failed. * fix: bypass buffered writer for GET/HEAD/POST requests --- server/webdav/buffered_response_writer.go | 46 +++++++++++++++++++++++ server/webdav/webdav.go | 23 +++++++----- 2 files changed, 60 insertions(+), 9 deletions(-) create mode 100644 server/webdav/buffered_response_writer.go diff --git a/server/webdav/buffered_response_writer.go b/server/webdav/buffered_response_writer.go new file mode 100644 index 00000000..ed653eae --- /dev/null +++ b/server/webdav/buffered_response_writer.go @@ -0,0 +1,46 @@ +package webdav + +import ( + "net/http" +) + +type bufferedResponseWriter struct { + statusCode int + data []byte + header http.Header +} + +func (w *bufferedResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *bufferedResponseWriter) Write(bytes []byte) (int, error) { + w.data = append(w.data, bytes...) + return len(bytes), nil +} + +func (w *bufferedResponseWriter) WriteHeader(statusCode int) { + if w.statusCode == 0 { + w.statusCode = statusCode + } +} + +func (w *bufferedResponseWriter) WriteToResponse(rw http.ResponseWriter) (int, error) { + h := rw.Header() + for k, vs := range w.header { + for _, v := range vs { + h.Add(k, v) + } + } + rw.WriteHeader(w.statusCode) + return rw.Write(w.data) +} + +func newBufferedResponseWriter() *bufferedResponseWriter { + return &bufferedResponseWriter{ + statusCode: 0, + } +} diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index ac9d975a..7ab8dbf5 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -45,30 +45,33 @@ func (h *Handler) stripPrefix(p string) (string, int, error) { func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { status, err := http.StatusBadRequest, errUnsupportedMethod + brw := newBufferedResponseWriter() + useBufferedWriter := true if h.LockSystem == nil { status, err = http.StatusInternalServerError, errNoLockSystem } else { switch r.Method { case "OPTIONS": - status, err = h.handleOptions(w, r) + status, err = h.handleOptions(brw, r) case "GET", "HEAD", "POST": + useBufferedWriter = false status, err = h.handleGetHeadPost(w, r) case "DELETE": - status, err = h.handleDelete(w, r) + status, err = h.handleDelete(brw, r) case "PUT": - status, err = h.handlePut(w, r) + status, err = h.handlePut(brw, r) case "MKCOL": - status, err = h.handleMkcol(w, r) + status, err = h.handleMkcol(brw, r) case "COPY", "MOVE": - status, err = h.handleCopyMove(w, r) + status, err = h.handleCopyMove(brw, r) case "LOCK": - status, err = h.handleLock(w, r) + status, err = h.handleLock(brw, r) case "UNLOCK": - status, err = h.handleUnlock(w, r) + status, err = h.handleUnlock(brw, r) case "PROPFIND": - status, err = h.handlePropfind(w, r) + status, err = h.handlePropfind(brw, r) case "PROPPATCH": - status, err = h.handleProppatch(w, r) + status, err = h.handleProppatch(brw, r) } } @@ -77,6 +80,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if status != http.StatusNoContent { w.Write([]byte(StatusText(status))) } + } else if useBufferedWriter { + brw.WriteToResponse(w) } if h.Logger != nil && err != nil { h.Logger(r, err)