基于 Go 1.25 的新特性,我将之前的工具包进行了全面重构和升级。

新版主要引入了 os.Root 保障文件操作安全,使用迭代器和泛型提升代码的现代感与复用性,并整合了 testing/fstest 优化测试体验。

  • 安全升级:os.Root 防御路径穿越:使用 Go 1.24 引入的 os.Root 防范路径穿越攻击,安全地处理文件操作,同时也保留了不依赖此功能的传统方法
  • 迭代器模式:利用 Go 1.23 引入的 iter.Seq,为 ListFiles 等方法设计了迭代器版本,实现内存友好的流式处理
  • 泛型与类型安全:在迭代器和 BatchProcessor 中应用泛型,在编译时提供更强的类型安全保障。
  • 测试辅助:集成了 testing/fstest.MapFS,方便构建基于内存的模拟文件系统进行单元测试。
  • 结构化管理与并发:引入 FileUtil 结构体封装基础路径,统一管理操作上下文。BatchProcessor 结构体则用于对目录下的文件进行并发的批量处理。

一、新版 fileutil 包代码

// Package fileutil 提供基于 Go 1.24+ 新特性的文件/目录操作工具。
package fileutil

import (
	"context"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"iter"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
)

// ============================================
// 1. 基础工具与常量
// ============================================

// Exists 判断路径是否存在。
func Exists(path string) bool {
	_, err := os.Stat(path)
	return !errors.Is(err, fs.ErrNotExist)
}

// IsDir 判断路径是否为目录。
func IsDir(path string) (bool, error) {
	info, err := os.Stat(path)
	if err != nil {
		return false, err
	}
	return info.IsDir(), nil
}

// IsFile 判断路径是否为普通文件。
func IsFile(path string) (bool, error) {
	info, err := os.Stat(path)
	if err != nil {
		return false, err
	}
	return !info.IsDir(), nil
}

// EnsureDir 确保目录存在,否则递归创建。
func EnsureDir(path string) error {
	if Exists(path) {
		isDir, err := IsDir(path)
		if err != nil {
			return err
		}
		if !isDir {
			return fmt.Errorf("path exists but is not a directory: %s", path)
		}
		return nil
	}
	return os.MkdirAll(path, 0755)
}

// ============================================
// 2. 核心结构:FileUtil (基于 os.Root)
// ============================================

// FileUtil 文件操作工具的核心结构体,封装了 os.Root。
type FileUtil struct {
	root *os.Root
	path string // 基础路径,用于日志或显示
}

// NewFileUtil 创建一个新的 FileUtil 实例。
// basePath 是所有后续操作受限的根目录。
func NewFileUtil(basePath string) (*FileUtil, error) {
	root, err := os.OpenRoot(basePath)
	if err != nil {
		return nil, fmt.Errorf("failed to open root %q: %w", basePath, err)
	}
	return &FileUtil{root: root, path: basePath}, nil
}

// Close 释放 FileUtil 持有的资源。
func (fu *FileUtil) Close() error {
	if fu.root != nil {
		return fu.root.Close()
	}
	return nil
}

// BasePath 返回 FileUtil 的根路径。
func (fu *FileUtil) BasePath() string {
	return fu.path
}

// ============================================
// 3. 安全的文件/目录操作 (使用 os.Root)
// ============================================

// SafeReadFile 安全地读取文件内容。
func (fu *FileUtil) SafeReadFile(relPath string) ([]byte, error) {
	return fu.root.ReadFile(relPath)
}

// SafeWriteFile 安全地将数据写入文件。
func (fu *FileUtil) SafeWriteFile(relPath string, data []byte, perm fs.FileMode) error {
	return fu.root.WriteFile(relPath, data, perm)
}

// SafeMkdirAll 安全地创建多级目录。
func (fu *FileUtil) SafeMkdirAll(relPath string, perm fs.FileMode) error {
	return fu.root.MkdirAll(relPath, perm)
}

// SafeRemove 安全地移除文件或空目录。
func (fu *FileUtil) SafeRemove(relPath string) error {
	return fu.root.Remove(relPath)
}

// ============================================
// 4. 传统文件/目录操作 (不使用 os.Root)
// ============================================

// CopyFile 复制文件从 src 到 dst。
func CopyFile(src, dst string) error {
	srcFile, err := os.Open(src)
	if err != nil {
		return fmt.Errorf("open source file: %w", err)
	}
	defer srcFile.Close()

	srcInfo, err := srcFile.Stat()
	if err != nil {
		return fmt.Errorf("stat source file: %w", err)
	}

	dstFile, err := os.Create(dst)
	if err != nil {
		return fmt.Errorf("create destination file: %w", err)
	}
	defer dstFile.Close()

	if _, err := io.Copy(dstFile, srcFile); err != nil {
		return fmt.Errorf("copy content: %w", err)
	}

	if err := dstFile.Sync(); err != nil {
		return fmt.Errorf("sync destination file: %w", err)
	}
	return os.Chmod(dst, srcInfo.Mode())
}

// MoveFile 移动文件。
func MoveFile(src, dst string) error {
	err := os.Rename(src, dst)
	if err == nil {
		return nil
	}
	// 跨文件系统降级处理
	if err := CopyFile(src, dst); err != nil {
		return fmt.Errorf("copy file during move: %w", err)
	}
	return os.Remove(src)
}

// CopyDir 递归复制整个目录。
func CopyDir(src, dst string) error {
	srcInfo, err := os.Stat(src)
	if err != nil {
		return fmt.Errorf("stat source directory: %w", err)
	}
	if !srcInfo.IsDir() {
		return fmt.Errorf("source is not a directory: %s", src)
	}

	if err := os.MkdirAll(dst, srcInfo.Mode()); err != nil {
		return fmt.Errorf("create destination directory: %w", err)
	}

	entries, err := os.ReadDir(src)
	if err != nil {
		return fmt.Errorf("read source directory: %w", err)
	}

	for _, entry := range entries {
		srcPath := filepath.Join(src, entry.Name())
		dstPath := filepath.Join(dst, entry.Name())
		if entry.IsDir() {
			if err := CopyDir(srcPath, dstPath); err != nil {
				return err
			}
		} else {
			if err := CopyFile(srcPath, dstPath); err != nil {
				return err
			}
		}
	}
	return nil
}

// DeleteDir 删除整个目录 (类似 rm -rf)。
func DeleteDir(path string) error {
	if !Exists(path) {
		return nil
	}
	return os.RemoveAll(path)
}

// ReadString 快速将文件内容读取为字符串。
func ReadString(path string) (string, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return "", err
	}
	return string(data), nil
}

// WriteString 快速将字符串写入文件。
func WriteString(path, content string) error {
	return os.WriteFile(path, []byte(content), 0644)
}

// ============================================
// 5. 迭代器与流式处理
// ============================================

// Lines 返回一个迭代器,用于逐行读取文件内容。
func Lines(path string) (iter.Seq[string], func() error) {
	file, err := os.Open(path)
	if err != nil {
		return func(yield func(string) bool) {
			return
		}, func() error { return err }
	}

	return func(yield func(string) bool) {
			defer file.Close()
			scanner := bufio.NewScanner(file)
			for scanner.Scan() {
				if !yield(scanner.Text()) {
					return
				}
			}
		}, func() error {
			defer file.Close()
			scanner := bufio.NewScanner(file)
			for scanner.Scan() {
				// 消耗迭代器以检查错误
			}
			return scanner.Err()
		}
}

// ListFilesIter 返回一个迭代器,流式地遍历目录下的文件。
func ListFilesIter(dir string, extensions []string) (iter.Seq[string], func() error) {
	var walkErr error
	seq := func(yield func(string) bool) {
		walkErr = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
			if err != nil {
				return err
			}
			if d.IsDir() {
				return nil
			}
			if len(extensions) > 0 {
				ext := strings.ToLower(filepath.Ext(path))
				matched := false
				for _, e := range extensions {
					if strings.ToLower(e) == ext {
						matched = true
						break
					}
				}
				if !matched {
					return nil
				}
			}
			if !yield(path) {
				return filepath.SkipAll
			}
			return nil
		})
	}
	return seq, func() error { return walkErr }
}

// DirSize 计算目录总大小。
func DirSize(path string) (int64, error) {
	var size int64
	err := filepath.WalkDir(path, func(_ string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if !d.IsDir() {
			info, err := d.Info()
			if err != nil {
				return err
			}
			size += info.Size()
		}
		return nil
	})
	return size, err
}

// ============================================
// 6. 原子写入与批量处理
// ============================================

// SafeWrite 原子性地将数据写入文件。
func SafeWrite(path string, data []byte, perm fs.FileMode) error {
	dir, filename := filepath.Split(path)
	tmpFile, err := os.CreateTemp(dir, ".tmp-*"+filename)
	if err != nil {
		return err
	}
	tmpName := tmpFile.Name()
	defer os.Remove(tmpName)

	if _, err := tmpFile.Write(data); err != nil {
		tmpFile.Close()
		return err
	}
	if err := tmpFile.Sync(); err != nil {
		tmpFile.Close()
		return err
	}
	if err := tmpFile.Close(); err != nil {
		return err
	}
	if err := os.Chmod(tmpName, perm); err != nil {
		return err
	}
	return os.Rename(tmpName, path)
}

// BatchProcessor 用于并发处理目录下的一系列文件。
type BatchProcessor[T any] struct {
	workers int
}

// NewBatchProcessor 创建一个新的批量处理器。
func NewBatchProcessor[T any](workers int) *BatchProcessor[T] {
	if workers <= 0 {
		workers = runtime.NumCPU()
	}
	return &BatchProcessor[T]{workers: workers}
}

// Process 对目录 dir 下匹配扩展名的所有文件,并发执行处理函数。
func (bp *BatchProcessor[T]) Process(ctx context.Context, dir string, extensions []string, handler func(string) (T, error)) ([]T, []error) {
	fileSeq, errFn := ListFilesIter(dir, extensions)
	var wg sync.WaitGroup
	fileCh := make(chan string)
	resultCh := make(chan T)
	errCh := make(chan error)

	// 启动 workers
	for i := 0; i < bp.workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for file := range fileCh {
				select {
				case <-ctx.Done():
					return
				default:
					if res, err := handler(file); err != nil {
						errCh <- fmt.Errorf("file %s: %w", file, err)
					} else {
						resultCh <- res
					}
				}
			}
		}()
	}

	// 发送文件路径
	go func() {
		for file := range fileSeq {
			select {
			case <-ctx.Done():
				break
			case fileCh <- file:
			}
		}
		close(fileCh)
	}()

	// 等待并收集结果
	go func() {
		wg.Wait()
		close(resultCh)
		close(errCh)
	}()

	var results []T
	var errs []error
	for {
		select {
		case res, ok := <-resultCh:
			if !ok {
				resultCh = nil
			} else {
				results = append(results, res)
			}
		case err, ok := <-errCh:
			if !ok {
				errCh = nil
			} else {
				errs = append(errs, err)
			}
		}
		if resultCh == nil && errCh == nil {
			break
		}
	}

	if walkErr := errFn(); walkErr != nil {
		errs = append(errs, fmt.Errorf("walk error: %w", walkErr))
	}

	return results, errs
}

// ============================================
// 7. 测试辅助 (testing/fstest)
// ============================================

// NewTestFS 创建一个用于测试的内存文件系统。
func NewTestFS() fstest.MapFS {
	return make(fstest.MapFS)
}

// AddTestFile 向测试文件系统添加一个文件。
func AddTestFile(fsys fstest.MapFS, path, content string) {
	fsys[path] = &fstest.MapFile{Data: []byte(content)}
}

二、使用示例

func main() {
	// 1. 使用 os.Root 进行安全操作
	fu, err := fileutil.NewFileUtil("./data")
	if err != nil {
		log.Fatal(err)
	}
	defer fu.Close()

	if err := fu.SafeMkdirAll("logs", 0755); err != nil {
		log.Fatal(err)
	}
	
	if err := fu.SafeWriteFile("config.json", []byte(`{"key": "value"}`), 0644); err != nil {
		log.Fatal(err)
	}
	
	data, err := fu.SafeReadFile("config.json")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(data))

	// 2. 使用迭代器逐行读取
	linesSeq, errFn := fileutil.Lines("./data/large.log")
	var count int
	for line := range linesSeq {
		_ = line
		count++
	}
	if err := errFn(); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Total lines: %d\n", count)

	// 3. 批量处理文件
	bp := fileutil.NewBatchProcessor[string](4)
	results, errs := bp.Process(context.Background(), "./data", []string{".log"}, func(path string) (string, error) {
		return fileutil.ReadString(path)
	})
	for _, err := range errs {
		log.Printf("Error: %v", err)
	}
	fmt.Printf("Processed %d files\n", len(results))
}

三、关键优化点详解

  1. os.Root:安全的文件操作基础
    通过 os.OpenRoot(basePath) 创建一个安全的根,后续所有操作(如 ReadFile)都会自动被限制在这个根目录内。这从根本上杜绝了路径穿越攻击,是处理来自用户输入的文件路径时最推荐的做法
  2. 迭代器:内存友好的流式处理
    对于大文件或包含海量文件的目录,一次性加载所有内容可能导致内存溢出。迭代器模式允许你以流式方式逐行或逐个文件地处理数据,例如 Lines 函数逐行返回内容,内存占用极低。
  3. 泛型批处理器:复用并发处理逻辑
    BatchProcessor[T] 是一个通用的批处理器,它封装了并发遍历、处理文件并收集结果的复杂逻辑。你只需指定返回类型 T 和核心处理函数,而无需关心 goroutine 管理、channel 同步等细节。
  4. testing/fstest:轻量级单元测试
    传统的文件系统单元测试往往依赖真实文件,过程繁琐且易出错。testing/fstest 允许在内存中轻松构建一个模拟的文件系统,让测试变得快速、可靠且易于管理。

希望这个融合了 Go 新特性的工具包能让你的开发工作更高效、更安全。