Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ values-dev.yaml

*.tsbuildinfo

.coda/
72 changes: 37 additions & 35 deletions backend/application/base/appinfra/app_infra.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct"
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
Expand Down Expand Up @@ -346,40 +345,43 @@ func initKnowledgeEventBusProducer() (eventbus.Producer, error) {
}

func initCodeRunner() coderunner.Runner {
switch typ := os.Getenv(consts.CodeRunnerType); typ {
case "sandbox":
getAndSplit := func(key string) []string {
v := os.Getenv(key)
if v == "" {
return nil
}
return strings.Split(v, ",")
}
config := &sandbox.Config{
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv),
AllowRead: getAndSplit(consts.CodeRunnerAllowRead),
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite),
AllowNet: getAndSplit(consts.CodeRunnerAllowNet),
AllowRun: getAndSplit(consts.CodeRunnerAllowRun),
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI),
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
TimeoutSeconds: 0,
MemoryLimitMB: 0,
}
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil {
config.TimeoutSeconds = f
} else {
config.TimeoutSeconds = 60.0
}
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil {
config.MemoryLimitMB = mem
} else {
config.MemoryLimitMB = 100
}
return sandbox.NewRunner(config)
default:
return direct.NewRunner()
// 为了安全考虑,移除不安全的direct runner,强制使用sandbox
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里我们不需要移除本地运行方法吧?不然对现有的部分用户来说可能是破坏性的。
只需要将新用户的代码默认运行方式设置为 sandbox 即可,历史用户升级不受影响。

getAndSplit := func(key string) []string {
v := os.Getenv(key)
if v == "" {
return nil
}
return strings.Split(v, ",")
}

// 使用安全的默认配置
config := &sandbox.Config{
AllowEnv: getAndSplit(consts.CodeRunnerAllowEnv), // 默认为空,禁止环境变量访问
AllowRead: getAndSplit(consts.CodeRunnerAllowRead), // 默认为空,禁止文件读取
AllowWrite: getAndSplit(consts.CodeRunnerAllowWrite), // 默认为空,禁止文件写入
AllowNet: getAndSplit(consts.CodeRunnerAllowNet), // 默认为空,禁止网络访问
AllowRun: getAndSplit(consts.CodeRunnerAllowRun), // 默认为空,禁止运行外部程序
AllowFFI: getAndSplit(consts.CodeRunnerAllowFFI), // 默认为空,禁止FFI调用
NodeModulesDir: os.Getenv(consts.CodeRunnerNodeModulesDir),
TimeoutSeconds: 0,
MemoryLimitMB: 0,
}

// 设置安全的超时时间,最大30秒
if f, err := strconv.ParseFloat(os.Getenv(consts.CodeRunnerTimeoutSeconds), 64); err == nil && f > 0 && f <= 30 {
config.TimeoutSeconds = f
} else {
config.TimeoutSeconds = 30.0 // 默认30秒超时
}

// 设置安全的内存限制,最大100MB
if mem, err := strconv.ParseInt(os.Getenv(consts.CodeRunnerMemoryLimitMB), 10, 64); err == nil && mem > 0 && mem <= 100 {
config.MemoryLimitMB = mem
} else {
config.MemoryLimitMB = 100 // 默认100MB内存限制
}

return sandbox.NewRunner(config)
}

func initOCR() ocr.OCR {
Expand Down Expand Up @@ -798,4 +800,4 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
}

return emb, nil
}
}
117 changes: 114 additions & 3 deletions backend/infra/impl/document/parser/builtin/parse_markdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
Expand All @@ -47,6 +48,7 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
return nil, err
}


node := mdParser.Parse(text.NewReader(b))
cs := config.ChunkingStrategy
ps := config.ParsingStrategy
Expand Down Expand Up @@ -101,13 +103,118 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
return text
}

// validateImageURL 验证图片URL的安全性
validateImageURL := func(urlString string) error {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个貌似不在漏洞修复范围内,暂时不用改?

parsedURL, err := url.Parse(urlString)
if err != nil {
return err
}

// 只允许HTTP/HTTPS
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return fmt.Errorf("unsupported scheme: %s", parsedURL.Scheme)
}

// 检查域名白名单
allowedDomains := []string{
"images.unsplash.com",
"cdn.example.com",
"github.com",
"githubusercontent.com",
// 可以根据需要添加其他受信任的域名
}

hostname := parsedURL.Hostname()
for _, domain := range allowedDomains {
if hostname == domain || strings.HasSuffix(hostname, "."+domain) {
return nil
}
}

return fmt.Errorf("domain not allowed: %s", hostname)
}

// isPrivateIPAddress reports whether ip must be treated as an internal
// (non-public) address and therefore refused as a download target.
//
// It returns true for:
//   - the unspecified address (0.0.0.0, ::)
//   - loopback (127.0.0.0/8, ::1)
//   - private ranges (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7)
//   - link-local unicast (169.254.0.0/16, fe80::/10)
//   - the shared address space 100.64.0.0/10 (RFC 6598)
//   - any value that is not a well-formed IPv4 or IPv6 address
//
// IPv4-mapped IPv6 addresses are classified by their embedded IPv4 address.
var isPrivateIPAddress = func(ip net.IP) bool {
	// Fail closed: this is a deny-list check, so an address that cannot be
	// classified (nil or wrong length) must not be let through.
	if ip.To16() == nil {
		return true
	}

	// The standard-library predicates cover the previously hand-listed CIDR
	// ranges without re-parsing CIDR strings on every call, and additionally
	// catch the unspecified address and IPv6 link-local addresses.
	if ip.IsUnspecified() || ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
		return true
	}

	// 100.64.0.0/10: the second octet is in [64, 127], i.e. its top two bits
	// are 01. This range is not publicly routable.
	if v4 := ip.To4(); v4 != nil && v4[0] == 100 && v4[1]&0xc0 == 0x40 {
		return true
	}

	return false
}

// isPrivateIP reports whether host should be refused because it is, or
// resolves to, a private address.
//
// An IP literal is classified directly. Any other value is looked up via
// DNS: a failed lookup is treated as private (access denied), and a single
// private address among the results makes the whole host private.
isPrivateIP := func(host string) bool {
	// Fast path: host is already an IP literal.
	if literal := net.ParseIP(host); literal != nil {
		return isPrivateIPAddress(literal)
	}

	// Not a literal, so treat it as a domain name and resolve it.
	resolved, lookupErr := net.LookupIP(host)
	if lookupErr != nil {
		// Resolution failed: deny access.
		return true
	}

	// Every resolved address has to pass the check.
	for i := range resolved {
		if isPrivateIPAddress(resolved[i]) {
			return true
		}
	}
	return false
}

downloadImage := func(ctx context.Context, url string) ([]byte, error) {
client := &http.Client{Timeout: 5 * time.Second}
// URL验证
if err := validateImageURL(url); err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}

// 使用安全的HTTP客户端
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// 禁止访问私有IP
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}

if isPrivateIP(host) {
return nil, fmt.Errorf("access to private IP denied: %s", host)
}

return (&net.Dialer{}).DialContext(ctx, network, addr)
},
},
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}


// 添加安全头
req.Header.Set("User-Agent", "CozeStudio/1.0")

resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to download image: %w", err)
Expand All @@ -118,7 +225,11 @@ func ParseMarkdown(config *contract.Config, storage storage.Storage, ocr ocr.OCR
return nil, fmt.Errorf("failed to download image, status code: %d", resp.StatusCode)
}

data, err := io.ReadAll(resp.Body)
// 限制响应大小
const maxImageSize = 10 * 1024 * 1024 // 10MB
limitedReader := io.LimitReader(resp.Body, maxImageSize)

data, err := io.ReadAll(limitedReader)
if err != nil {
return nil, fmt.Errorf("failed to read image content: %w", err)
}
Expand Down
18 changes: 9 additions & 9 deletions backend/infra/impl/rdb/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,15 +646,15 @@ func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLReques
var processedParams []interface{}
var err error

// Handle SQLType: if raw, do not process params
// 禁用原始SQL执行以防止SQL注入攻击
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不能直接禁用,workflow 数据库节点功能依赖这里,通过 env 加个配置让用户决定是否允许直接运行 sql,这样可能好点?

if req.SQLType == entity2.SQLType_Raw {
processedSQL = req.SQL
processedParams = nil
} else {
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
if err != nil {
return nil, fmt.Errorf("failed to process parameters: %v", err)
}
return nil, fmt.Errorf("raw SQL execution is not allowed for security reasons")
}

// 强制使用参数化查询
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
if err != nil {
return nil, fmt.Errorf("failed to process parameters: %v", err)
}

operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
Expand Down Expand Up @@ -1011,4 +1011,4 @@ func (m *mysqlService) buildNestedConditions(condition *rdb.ComplexCondition) (s
return whereClause.String(), values, nil
}
return "", values, nil
}
}
Loading