From 1ae1737c1c4be766771e6f1c8f84c176acf61174 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Fri, 17 May 2024 17:28:57 +0200 Subject: [PATCH] added RemoveFileProvider --- fileprovider.go | 71 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/fileprovider.go b/fileprovider.go index 7833f84..3f4f6ff 100644 --- a/fileprovider.go +++ b/fileprovider.go @@ -26,12 +26,12 @@ type dirFileProvider struct { dir fs.File } -func (d dirFileProvider) HasFile(filename string) (bool, error) { - return d.dir.Join(filename).Exists(), nil +func (p dirFileProvider) HasFile(filename string) (bool, error) { + return p.dir.Join(filename).Exists(), nil } -func (d dirFileProvider) ListFiles(ctx context.Context) (filenames []string, err error) { - err = d.dir.ListDirContext(ctx, func(file fs.File) error { +func (p dirFileProvider) ListFiles(ctx context.Context) (filenames []string, err error) { + err = p.dir.ListDirContext(ctx, func(file fs.File) error { filenames = append(filenames, file.Name()) return nil }) @@ -42,8 +42,8 @@ func (d dirFileProvider) ListFiles(ctx context.Context) (filenames []string, err return filenames, nil } -func (d dirFileProvider) ReadFile(ctx context.Context, filename string) ([]byte, error) { - return d.dir.Join(filename).ReadAllContext(ctx) +func (p dirFileProvider) ReadFile(ctx context.Context, filename string) ([]byte, error) { + return p.dir.Join(filename).ReadAllContext(ctx) } /////////////////////////////////////////////////////////////////////////////// @@ -52,7 +52,7 @@ func (d dirFileProvider) ReadFile(ctx context.Context, filename string) ([]byte, // ExtFileProvider returns a FileProvider that extends a base FileProvider // with additional files that will be returned before the files of the base FileProvider. func ExtFileProvider(base FileProvider, extFiles ...fs.FileReader) FileProvider { - return &extFileProvider{base, extFiles} + return extFileProvider{base, extFiles} } type extFileProvider struct { @@ -60,21 +60,21 @@ type extFileProvider struct { extFiles []fs.FileReader } -func (e *extFileProvider) HasFile(filename string) (bool, error) { - for _, f := range e.extFiles { +func (p extFileProvider) HasFile(filename string) (bool, error) { + for _, f := range p.extFiles { if f.Name() == filename { return true, nil } } - return e.base.HasFile(filename) + return p.base.HasFile(filename) } -func (d extFileProvider) ListFiles(ctx context.Context) (filenames []string, err error) { - filenames, err = d.base.ListFiles(ctx) +func (p extFileProvider) ListFiles(ctx context.Context) (filenames []string, err error) { + filenames, err = p.base.ListFiles(ctx) if err != nil { return nil, err } - for _, f := range d.extFiles { + for _, f := range p.extFiles { if !slices.Contains(filenames, f.Name()) { filenames = append(filenames, f.Name()) } @@ -83,13 +83,52 @@ func (d extFileProvider) ListFiles(ctx context.Context) (filenames []string, err return filenames, nil } -func (d extFileProvider) ReadFile(ctx context.Context, filename string) ([]byte, error) { - for _, f := range d.extFiles { +func (p extFileProvider) ReadFile(ctx context.Context, filename string) ([]byte, error) { + for _, f := range p.extFiles { if f.Name() == filename { return f.ReadAllContext(ctx) } } - return d.base.ReadFile(ctx, filename) + return p.base.ReadFile(ctx, filename) +} + +/////////////////////////////////////////////////////////////////////////////// +// RemoveFileProvider + +// RemoveFileProvider returns a FileProvider that wraps a base FileProvider +// and does not return files with the passed removeFilenames. +func RemoveFileProvider(base FileProvider, removeFilenames ...string) FileProvider { + return &removeFileProvider{base, removeFilenames} +} + +type removeFileProvider struct { + base FileProvider + remove []string +} + +func (p *removeFileProvider) HasFile(filename string) (bool, error) { + if slices.Contains(p.remove, filename) { + return false, nil + } + return p.base.HasFile(filename) +} + +func (p removeFileProvider) ListFiles(ctx context.Context) (filenames []string, err error) { + filenames, err = p.base.ListFiles(ctx) + if err != nil { + return nil, err + } + filenames = slices.DeleteFunc(filenames, func(filename string) bool { + return slices.Contains(p.remove, filename) + }) + return filenames, nil +} + +func (p removeFileProvider) ReadFile(ctx context.Context, filename string) ([]byte, error) { + if slices.Contains(p.remove, filename) { + return nil, fs.NewErrPathDoesNotExist(filename) + } + return p.base.ReadFile(ctx, filename) } // func MemFileProvider(file fs.MemFile) FileProvider {