From 59a4b7d25513787639c80dd7db6240feac4b51ca Mon Sep 17 00:00:00 2001 From: wener Date: Fri, 17 Nov 2023 12:37:49 +0800 Subject: [PATCH] feat: allowed add audit query params to downstream proxy --- code/contexts/keys.go | 107 +++++++++++++++++++++++++++ code/handlers/event_common_action.go | 3 +- code/handlers/event_msg_action.go | 9 ++- code/handlers/handler.go | 25 +++++++ code/initialization/config.go | 2 + code/services/openai/gpt3.go | 22 ++++-- code/services/openai/gpt3_test.go | 2 +- 7 files changed, 161 insertions(+), 9 deletions(-) create mode 100644 code/contexts/keys.go diff --git a/code/contexts/keys.go b/code/contexts/keys.go new file mode 100644 index 00000000..3cfcf2af --- /dev/null +++ b/code/contexts/keys.go @@ -0,0 +1,107 @@ +package contexts + +import ( + "context" + "fmt" + "net/url" + "reflect" +) + +type ChatContext struct { + Tenant string + ChatID string + ChatType string + MessageID string + MessageType string + SenderOpenID string + SenderType string + SenderUnionID string + SenderUserID string + SessionID string +} + +func (cc *ChatContext) Encode() string { + v := url.Values{} + v.Set("feishu.tenant", cc.Tenant) + v.Set("feishu.session_id", cc.SessionID) + v.Set("feishu.chat_id", cc.ChatID) + v.Set("feishu.chat_type", cc.ChatType) + v.Set("feishu.message_id", cc.MessageID) + v.Set("feishu.message_type", cc.MessageType) + v.Set("feishu.sender_user_id", cc.SenderUserID) + v.Set("feishu.sender_union_id", cc.SenderUnionID) + v.Set("feishu.sender_open_id", cc.SenderOpenID) + v.Set("feishu.sender_type", cc.SenderType) + for k, vv := range v { + if len(vv) == 0 || vv[0] == "" { + delete(v, k) + } + } + return v.Encode() +} + +var ChatContextKey = CreateContextKey[*ChatContext]() + +type ContextKey[T any] interface { + Value(ctx context.Context) (T, bool) + Get(ctx context.Context) T + Must(ctx context.Context) T + WithValue(ctx context.Context, val T) context.Context +} + +type key[T any] struct { + opts CreateContextKeyOptions[T] +} + +func (k key[T]) 
Value(ctx context.Context) (T, bool) { + o, ok := ctx.Value(k.opts.key).(T) + return o, ok +} + +func (k key[T]) Get(ctx context.Context) T { + o, _ := ctx.Value(k.opts.key).(T) + return o +} + +func (k key[T]) Must(ctx context.Context) T { + o, ok := ctx.Value(k.opts.key).(T) + if !ok { + panic(fmt.Errorf("%s not found in context", k.String())) + } + return o +} + +func (k key[T]) WithValue(ctx context.Context, val T) context.Context { + return context.WithValue(ctx, k.opts.key, val) +} + +func (k key[T]) String() string { + name := k.opts.Name + if name != "" { + name = "@" + name + } + return fmt.Sprintf("ContextKey(%s%s)", reflect.TypeOf(new(T)).Elem().String(), name) +} + +var _ ContextKey[string] = (*key[string])(nil) + +type CreateContextKeyOptions[T any] struct { + Name string + key any +} + +func CreateContextKey[T any](opts ...CreateContextKeyOptions[T]) ContextKey[T] { + var opt CreateContextKeyOptions[T] + if len(opts) > 0 { + // reduce + for _, o := range opts { + opt = o + } + } + if opt.Name != "" { + opt.key = opt.Name + } else { + opt.key = reflect.TypeOf(new(T)).Elem() + } + return &key[T]{opts: opt} +} diff --git a/code/handlers/event_common_action.go b/code/handlers/event_common_action.go index d6827b52..a28bf120 100644 --- a/code/handlers/event_common_action.go +++ b/code/handlers/event_common_action.go @@ -3,7 +3,7 @@ package handlers import ( "context" "fmt" - + "start-feishubot/contexts" "start-feishubot/initialization" "start-feishubot/services/openai" "start-feishubot/utils" @@ -21,6 +21,7 @@ type MsgInfo struct { imageKey string sessionId *string mention []*larkim.MentionEvent + Context *contexts.ChatContext } type ActionInfo struct { handler *MessageHandler diff --git a/code/handlers/event_msg_action.go b/code/handlers/event_msg_action.go index 6625a375..e59099a1 100644 --- a/code/handlers/event_msg_action.go +++ b/code/handlers/event_msg_action.go @@ -1,9 +1,11 @@ package handlers import ( + "context" "encoding/json" "fmt" "log" + 
"start-feishubot/contexts" "strings" "time" @@ -37,11 +39,14 @@ func (*MessageAction) Execute(a *ActionInfo) bool { Role: "user", Content: a.info.qParsed, }) + ctx := context.Background() + ctx = contexts.ChatContextKey.WithValue(ctx, a.info.Context) + //fmt.Println("msg", msg) //logger.Debug("msg", msg) // get ai mode as temperature aiMode := a.handler.sessionCache.GetAIMode(*a.info.sessionId) - completions, err := a.handler.gpt.Completions(msg, aiMode) + completions, err := a.handler.gpt.Completions(ctx, msg, aiMode) if err != nil { replyMsg(*a.ctx, fmt.Sprintf( "🤖️:消息机器人摆烂了,请稍后再试~\n错误信息: %v", err), a.info.msgId) @@ -70,7 +75,7 @@ func (*MessageAction) Execute(a *ActionInfo) bool { return true } -//判断msg中的是否包含system role +// 判断msg中的是否包含system role func hasSystemRole(msg []openai.Messages) bool { for _, m := range msg { if m.Role == "system" { diff --git a/code/handlers/handler.go b/code/handlers/handler.go index 3458c99d..9438727a 100644 --- a/code/handlers/handler.go +++ b/code/handlers/handler.go @@ -4,6 +4,8 @@ import ( "context" "fmt" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" + "log" + "start-feishubot/contexts" "start-feishubot/logger" "strings" @@ -85,6 +87,27 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2 sessionId: sessionId, mention: mention, } + { + get := func(s *string) string { + if s == nil { + return "" + } + return *s + } + cc := &contexts.ChatContext{ + MessageID: get(msgInfo.msgId), + MessageType: msgInfo.msgType, + ChatID: get(msgInfo.chatId), + ChatType: get(event.Event.Message.ChatType), + SessionID: get(msgInfo.sessionId), + SenderUserID: get(event.Event.Sender.SenderId.UserId), + SenderOpenID: get(event.Event.Sender.SenderId.OpenId), + SenderUnionID: get(event.Event.Sender.SenderId.UnionId), + SenderType: get(event.Event.Sender.SenderType), + Tenant: get(event.Event.Sender.TenantKey), + } + msgInfo.Context = cc + } data := &ActionInfo{ ctx: &ctx, handler: &m, @@ -127,6 +150,8 @@ func (m 
MessageHandler) judgeIfMentionMe(mention []*larkim. if len(mention) != 1 { return false } + // for simple debugging, find a way to pass the info to endpoint + log.Printf("mention: %s", larkcore.Prettify(mention[0])) return *mention[0].Name == m.config.FeishuBotName } diff --git a/code/initialization/config.go b/code/initialization/config.go index a33c760a..b8a978e7 100644 --- a/code/initialization/config.go +++ b/code/initialization/config.go @@ -38,6 +38,7 @@ type Config struct { AzureResourceName string AzureOpenaiToken string StreamMode bool + AuditQueryParams bool } var ( @@ -89,6 +90,7 @@ func LoadConfig(cfg string) *Config { AzureResourceName: getViperStringValue("AZURE_RESOURCE_NAME", ""), AzureOpenaiToken: getViperStringValue("AZURE_OPENAI_TOKEN", ""), StreamMode: getViperBoolValue("STREAM_MODE", false), + AuditQueryParams: getViperBoolValue("AUDIT_QUERY_PARAMS", false), } return config diff --git a/code/services/openai/gpt3.go b/code/services/openai/gpt3.go index 87425bbb..1b0a991a 100644 --- a/code/services/openai/gpt3.go +++ b/code/services/openai/gpt3.go @@ -1,7 +1,10 @@ package openai import ( + "context" "errors" + "start-feishubot/contexts" + "start-feishubot/initialization" "start-feishubot/logger" "strings" @@ -68,7 +71,7 @@ func (msg *Messages) CalculateTokenLength() int { return tokenizer.MustCalToken(text) } -func (gpt *ChatGPT) Completions(msg []Messages, aiMode AIMode) (resp Messages, +func (gpt *ChatGPT) Completions(ctx context.Context, msg []Messages, aiMode AIMode) (resp Messages, err error) { requestBody := ChatGPTRequestBody{ Model: gpt.Model, @@ -80,14 +83,23 @@ func (gpt *ChatGPT) Completions(msg []Messages, aiMode AIMode) (resp Messages, PresencePenalty: 0, } gptResponseBody := &ChatGPTResponseBody{} - url := gpt.FullUrl("chat/completions") + fullUrl := gpt.FullUrl("chat/completions") + if initialization.GetConfig().AuditQueryParams { + cc := 
contexts.ChatContextKey.Get(ctx) + if cc != nil { + encode := cc.Encode() + if encode != "" { + fullUrl = fullUrl + "?" + encode + } + } + } //fmt.Println(url) - logger.Debug(url) + logger.Debug(fullUrl) logger.Debug("request body ", requestBody) - if url == "" { + if fullUrl == "" { return resp, errors.New("无法获取openai请求地址") } - err = gpt.sendRequestWithBodyType(url, "POST", jsonBody, requestBody, gptResponseBody) + err = gpt.sendRequestWithBodyType(fullUrl, "POST", jsonBody, requestBody, gptResponseBody) if err == nil && len(gptResponseBody.Choices) > 0 { resp = gptResponseBody.Choices[0].Message } else { diff --git a/code/services/openai/gpt3_test.go b/code/services/openai/gpt3_test.go index fb442972..bc64e066 100644 --- a/code/services/openai/gpt3_test.go +++ b/code/services/openai/gpt3_test.go @@ -16,7 +16,7 @@ func TestCompletions(t *testing.T) { {Role: "user", Content: "翻译这段话: The assistant messages help store prior responses. They can also be written by a developer to help give examples of desired behavior."}, } gpt := NewChatGPT(*config) - resp, err := gpt.Completions(nil, msgs, Balance) + if err != nil { t.Errorf("TestCompletions failed with error: %v", err) }