mcp: add image url support (#1)
* refactor: remove redundant validation in service and clean up code - Remove validatePublishRequest function as gin binding already handles validation - Remove unused errors import - Simplify PublishContent method by relying on gin's built-in validation - Add comprehensive image processing support with URL download capability * 添加 MCP 说明文档,完善教程
This commit is contained in:
148
pkg/downloader/images.go
Normal file
148
pkg/downloader/images.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ImageDownloader 图片下载器
|
||||
type ImageDownloader struct {
|
||||
savePath string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewImageDownloader 创建图片下载器
|
||||
func NewImageDownloader(savePath string) *ImageDownloader {
|
||||
// 确保保存目录存在
|
||||
if err := os.MkdirAll(savePath, 0755); err != nil {
|
||||
panic(fmt.Sprintf("failed to create save path: %v", err))
|
||||
}
|
||||
|
||||
return &ImageDownloader{
|
||||
savePath: savePath,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DownloadImage 下载图片
|
||||
// 返回本地文件路径
|
||||
func (d *ImageDownloader) DownloadImage(imageURL string) (string, error) {
|
||||
// 验证URL格式
|
||||
if !d.isValidImageURL(imageURL) {
|
||||
return "", errors.New("invalid image URL format")
|
||||
}
|
||||
|
||||
// 下载图片数据
|
||||
resp, err := d.httpClient.Get(imageURL)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to download image")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("download failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取图片数据
|
||||
imageData, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to read image data")
|
||||
}
|
||||
|
||||
// 检测图片格式
|
||||
kind, err := filetype.Match(imageData)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to detect file type")
|
||||
}
|
||||
|
||||
if !filetype.IsImage(imageData) {
|
||||
return "", errors.New("downloaded file is not a valid image")
|
||||
}
|
||||
|
||||
// 生成唯一文件名
|
||||
fileName := d.generateFileName(imageURL, kind.Extension)
|
||||
filePath := filepath.Join(d.savePath, fileName)
|
||||
|
||||
// 如果文件已存在,直接返回路径
|
||||
if _, err := os.Stat(filePath); err == nil {
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// 保存到文件
|
||||
if err := os.WriteFile(filePath, imageData, 0644); err != nil {
|
||||
return "", errors.Wrap(err, "failed to save image")
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// DownloadImages 批量下载图片
|
||||
func (d *ImageDownloader) DownloadImages(imageURLs []string) ([]string, error) {
|
||||
var localPaths []string
|
||||
var errs []error
|
||||
|
||||
for _, imageURL := range imageURLs {
|
||||
localPath, err := d.DownloadImage(imageURL)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to download %s: %w", imageURL, err))
|
||||
continue
|
||||
}
|
||||
localPaths = append(localPaths, localPath)
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return localPaths, fmt.Errorf("download errors occurred: %v", errs)
|
||||
}
|
||||
|
||||
return localPaths, nil
|
||||
}
|
||||
|
||||
// isValidImageURL 检查是否为有效的图片URL
|
||||
func (d *ImageDownloader) isValidImageURL(rawURL string) bool {
|
||||
// 检查是否以http/https开头
|
||||
if !strings.HasPrefix(strings.ToLower(rawURL), "http://") &&
|
||||
!strings.HasPrefix(strings.ToLower(rawURL), "https://") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查URL格式
|
||||
parsedURL, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return parsedURL.Scheme != "" && parsedURL.Host != ""
|
||||
}
|
||||
|
||||
// generateFileName 生成唯一的文件名
|
||||
func (d *ImageDownloader) generateFileName(imageURL, extension string) string {
|
||||
// 使用URL的SHA256哈希作为文件名,确保唯一性
|
||||
hash := sha256.Sum256([]byte(imageURL))
|
||||
hashStr := fmt.Sprintf("%x", hash)
|
||||
|
||||
// 取前16位哈希值作为文件名
|
||||
shortHash := hashStr[:16]
|
||||
|
||||
// 添加时间戳确保更好的唯一性
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
return fmt.Sprintf("img_%s_%d.%s", shortHash, timestamp, extension)
|
||||
}
|
||||
|
||||
// IsImageURL 判断字符串是否为图片URL
|
||||
func IsImageURL(path string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(path), "http://") ||
|
||||
strings.HasPrefix(strings.ToLower(path), "https://")
|
||||
}
|
||||
102
pkg/downloader/images_test.go
Normal file
102
pkg/downloader/images_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsImageURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"https://example.com/image.jpg", true},
|
||||
{"http://example.com/image.png", true},
|
||||
{"HTTPS://example.com/image.gif", true},
|
||||
{"/local/path/image.jpg", false},
|
||||
{"./relative/path/image.png", false},
|
||||
{"image.jpg", false},
|
||||
{"ftp://example.com/image.jpg", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := IsImageURL(test.input)
|
||||
if result != test.expected {
|
||||
t.Errorf("IsImageURL(%q) = %v, expected %v", test.input, result, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewImageDownloader(t *testing.T) {
|
||||
tempDir := os.TempDir()
|
||||
testPath := filepath.Join(tempDir, "test_downloader")
|
||||
defer os.RemoveAll(testPath)
|
||||
|
||||
downloader := NewImageDownloader(testPath)
|
||||
|
||||
if downloader == nil {
|
||||
t.Fatal("NewImageDownloader returned nil")
|
||||
}
|
||||
|
||||
if downloader.savePath != testPath {
|
||||
t.Errorf("savePath = %q, expected %q", downloader.savePath, testPath)
|
||||
}
|
||||
|
||||
// 验证目录是否创建
|
||||
if _, err := os.Stat(testPath); os.IsNotExist(err) {
|
||||
t.Errorf("save path directory was not created: %s", testPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageDownloader_isValidImageURL(t *testing.T) {
|
||||
downloader := NewImageDownloader(os.TempDir())
|
||||
|
||||
tests := []struct {
|
||||
url string
|
||||
expected bool
|
||||
}{
|
||||
{"https://example.com/image.jpg", true},
|
||||
{"http://example.com/image.png", true},
|
||||
{"https://", false},
|
||||
{"http://", false},
|
||||
{"invalid-url", false},
|
||||
{"ftp://example.com/image.jpg", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := downloader.isValidImageURL(test.url)
|
||||
if result != test.expected {
|
||||
t.Errorf("isValidImageURL(%q) = %v, expected %v", test.url, result, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageDownloader_generateFileName(t *testing.T) {
|
||||
downloader := NewImageDownloader(os.TempDir())
|
||||
|
||||
url := "https://example.com/image.jpg"
|
||||
extension := "jpg"
|
||||
|
||||
fileName1 := downloader.generateFileName(url, extension)
|
||||
|
||||
// 文件名应该包含扩展名
|
||||
if filepath.Ext(fileName1) != "."+extension {
|
||||
t.Errorf("fileName should end with .%s, got %s", extension, fileName1)
|
||||
}
|
||||
|
||||
// 文件名应该包含img_前缀
|
||||
if !strings.HasPrefix(filepath.Base(fileName1), "img_") {
|
||||
t.Errorf("fileName should start with img_, got %s", fileName1)
|
||||
}
|
||||
|
||||
// 不同URL应该生成不同的文件名
|
||||
url2 := "https://example.com/different.jpg"
|
||||
fileName2 := downloader.generateFileName(url2, extension)
|
||||
if fileName1 == fileName2 {
|
||||
t.Errorf("different URLs should generate different file names")
|
||||
}
|
||||
}
|
||||
53
pkg/downloader/processor.go
Normal file
53
pkg/downloader/processor.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/xpzouying/xiaohongshu-mcp/configs"
|
||||
)
|
||||
|
||||
// ImageProcessor 图片处理器
|
||||
type ImageProcessor struct {
|
||||
downloader *ImageDownloader
|
||||
}
|
||||
|
||||
// NewImageProcessor 创建图片处理器
|
||||
func NewImageProcessor() *ImageProcessor {
|
||||
return &ImageProcessor{
|
||||
downloader: NewImageDownloader(configs.GetImagesPath()),
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessImages 处理图片列表,返回本地文件路径
|
||||
// 支持两种输入格式:
|
||||
// 1. URL格式 (http/https开头) - 自动下载到本地
|
||||
// 2. 本地文件路径 - 直接使用
|
||||
func (p *ImageProcessor) ProcessImages(images []string) ([]string, error) {
|
||||
var localPaths []string
|
||||
var urlsToDownload []string
|
||||
|
||||
// 分离URL和本地路径
|
||||
for _, image := range images {
|
||||
if IsImageURL(image) {
|
||||
urlsToDownload = append(urlsToDownload, image)
|
||||
} else {
|
||||
// 本地路径直接添加
|
||||
localPaths = append(localPaths, image)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量下载URL图片
|
||||
if len(urlsToDownload) > 0 {
|
||||
downloadedPaths, err := p.downloader.DownloadImages(urlsToDownload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download images: %w", err)
|
||||
}
|
||||
localPaths = append(localPaths, downloadedPaths...)
|
||||
}
|
||||
|
||||
if len(localPaths) == 0 {
|
||||
return nil, fmt.Errorf("no valid images found")
|
||||
}
|
||||
|
||||
return localPaths, nil
|
||||
}
|
||||
Reference in New Issue
Block a user