diff --git a/.gitignore b/.gitignore index dfdf34dc73..7d8b044509 100644 --- a/.gitignore +++ b/.gitignore @@ -60,3 +60,4 @@ values-dev.yaml *.tsbuildinfo +.coda/ diff --git a/backend/application/base/appinfra/app_infra.go b/backend/application/base/appinfra/app_infra.go index 651e2a8e1e..18e05d1356 100644 --- a/backend/application/base/appinfra/app_infra.go +++ b/backend/application/base/appinfra/app_infra.go @@ -51,7 +51,6 @@ import ( "github.com/coze-dev/coze-studio/backend/infra/contract/messages2query" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis" - "github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct" "github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox" builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin" "github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr" @@ -346,40 +345,43 @@ func initKnowledgeEventBusProducer() (eventbus.Producer, error) { } func initCodeRunner() coderunner.Runner { - switch typ := os.Getenv(consts.CodeRunnerType); typ { - case "sandbox": - getAndSplit := func(key string) []string { - v := os.Getenv(key) - if v == "" { - return nil - } - return strings.Split(v, ",") - } - config := &sandbox.Config{ - AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv), - AllowRead: getAndSplit(consts.CodeRunnerAllowRead), - AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite), - AllowNet: getAndSplit(consts.CodeRunnerAllowNet), - AllowRun: getAndSplit(consts.CodeRunnerAllowRun), - AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI), - NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir), - TimeoutSeconds: 0, - MemoryLimitMB: 0, - } - if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil { - config.TimeoutSeconds = f - } else { - config.TimeoutSeconds = 60.0 - } - if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil { - config.MemoryLimitMB = mem - } else { - config.MemoryLimitMB = 100 - } - return sandbox.NewRunner(config) - default: - return direct.NewRunner() + // 为了安全考虑,移除不安全的direct runner,强制使用sandbox + getAndSplit := func(key string) []string { + v := os.Getenv(key) + if v == "" { + return nil + } + return strings.Split(v, ",") + } + + // 使用安全的默认配置 + config := &sandbox.Config{ + AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv), // 默认为空,禁止环境变量访问 + AllowRead: getAndSplit(consts.CodeRunnerAllowRead), // 默认为空,禁止文件读取 + AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite), // 默认为空,禁止文件写入 + AllowNet: getAndSplit(consts.CodeRunnerAllowNet), // 默认为空,禁止网络访问 + AllowRun: getAndSplit(consts.CodeRunnerAllowRun), // 默认为空,禁止运行外部程序 + AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI), // 默认为空,禁止FFI调用 + NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir), + TimeoutSeconds: 0, + MemoryLimitMB: 0, + } + + // 设置安全的超时时间,最大30秒 + if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil && f > 0 && f <= 30 { + config.TimeoutSeconds = f + } else { + config.TimeoutSeconds = 30.0 // 默认30秒超时 } + + // 设置安全的内存限制,最大100MB + if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil && mem > 0 && mem <= 100 { + config.MemoryLimitMB = mem + } else { + config.MemoryLimitMB = 100 // 默认100MB内存限制 + } + + return sandbox.NewRunner(config) } func initOCR() ocr.OCR { @@ -798,4 +800,4 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) { } return emb, nil -} +} \ No newline at end of file diff --git a/backend/infra/impl/document/parser/builtin/parse_markdown.go b/backend/infra/impl/document/parser/builtin/parse_markdown.go index f734d5e8e9..c6e0b743c2 100644 --- a/backend/infra/impl/document/parser/builtin/parse_markdown.go +++ b/backend/infra/impl/document/parser/builtin/parse_markdown.go @@ -21,6 +21,7 @@ import ( "encoding/base64" "fmt" "io" + "net" "net/http" "net/url" "strings" @@ -47,6 +48,7 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR return nil, err } + node := mdParser.Parse(text.NewReader(b)) cs := config.ChunkingStrategy ps := config.ParsingStrategy @@ -101,13 +103,118 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR return text } + // validateImageURL 验证图片URL的安全性 + validateImageURL := func(urlString string) error { + parsedURL, err := url.Parse(urlString) + if err != nil { + return err + } + + // 只允许HTTP/HTTPS + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return fmt.Errorf("unsupported scheme: %s", parsedURL.Scheme) + } + + // 检查域名白名单 + allowedDomains := []string{ + "images.unsplash.com", + "cdn.example.com", + "github.com", + "githubusercontent.com", + // 可以根据需要添加其他受信任的域名 + } + + hostname := parsedURL.Hostname() + for _, domain := range allowedDomains { + if hostname == domain || strings.HasSuffix(hostname, "."+domain) { + return nil + } + } + + return fmt.Errorf("domain not allowed: %s", hostname) + } + + // isPrivateIPAddress 检查IP地址是否为私有地址 + isPrivateIPAddress := func(ip net.IP) bool { + // 检查私有IP范围 + privateRanges := []struct { + cidr string + }{ + {"10.0.0.0/8"}, + {"172.16.0.0/12"}, + {"192.168.0.0/16"}, + {"127.0.0.0/8"}, + {"169.254.0.0/16"}, // 链路本地地址 + {"::1/128"}, // IPv6 loopback + {"fc00::/7"}, // IPv6 私有地址 + } + + for _, r := range privateRanges { + _, cidr, _ := net.ParseCIDR(r.cidr) + if cidr.Contains(ip) { + return true + } + } + + return false + } + + // isPrivateIP 检查是否为私有IP地址 + isPrivateIP := func(host string) bool { + ip := net.ParseIP(host) + if ip == nil { + // 可能是域名,需要解析 + ips, err := net.LookupIP(host) + if err != nil { + return true // 解析失败,拒绝访问 + } + + // 检查所有解析的IP + for _, resolvedIP := range ips { + if isPrivateIPAddress(resolvedIP) { + return true + } + } + return false + } + + return isPrivateIPAddress(ip) + } + downloadImage := func(ctx context.Context, url string) ([]byte, error) { - client := &http.Client{Timeout: 5 * time.Second} + // URL验证 + if err := validateImageURL(url); err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + // 使用安全的HTTP客户端 + client := &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // 禁止访问私有IP + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + if isPrivateIP(host) { + return nil, fmt.Errorf("access to private IP denied: %s", host) + } + + return (&net.Dialer{}).DialContext(ctx, network, addr) + }, + }, + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) } - + + // 添加安全头 + req.Header.Set("User-Agent", "CozeStudio/1.0") + resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("failed to download image: %w", err) @@ -118,7 +225,11 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode) } - data, err := io.ReadAll(resp.Body) + // 限制响应大小 + const maxImageSize = 10 * 1024 * 1024 // 10MB + limitedReader := io.LimitReader(resp.Body, maxImageSize) + + data, err := io.ReadAll(limitedReader) if err != nil { return nil, fmt.Errorf("failed to read image content: %w", err) } diff --git a/backend/infra/impl/rdb/mysql.go b/backend/infra/impl/rdb/mysql.go index edc60331ed..0cb91e1402 100644 --- a/backend/infra/impl/rdb/mysql.go +++ b/backend/infra/impl/rdb/mysql.go @@ -646,15 +646,15 @@ func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLReques var processedParams []interface{} var err error - // Handle SQLType: if raw, do not process params + // 禁用原始SQL执行以防止SQL注入攻击 if req.SQLType == entity2.SQLType_Raw { - processedSQL = req.SQL - processedParams = nil - } else { - processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params) - if err != nil { - return nil, fmt.Errorf("failed to process parameters: %v", err) - } + return nil, fmt.Errorf("raw SQL execution is not allowed for security reasons") + } + + // 强制使用参数化查询 + processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params) + if err != nil { + return nil, fmt.Errorf("failed to process parameters: %v", err) } operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL) @@ -1011,4 +1011,4 @@ func (m *mysqlService) buildNestedConditions(condition *rdb.ComplexCondition) (s return whereClause.String(), values, nil } return "", values, nil -} +} \ No newline at end of file