From 3e37b44c40f8ab8ba10daf28f5ba2257c4cbca79 Mon Sep 17 00:00:00 2001 From: liuzhifei <2679431923@qq.com> Date: Thu, 23 May 2024 14:25:08 +0800 Subject: [PATCH] fix code struct add chat model todo --- localModel.go | 1 + main.go | 72 ++++++++++++++++++++++++++++++++++----------------- 2 files changed, 49 insertions(+), 24 deletions(-) create mode 100644 localModel.go diff --git a/localModel.go b/localModel.go new file mode 100644 index 0000000..06ab7d0 --- /dev/null +++ b/localModel.go @@ -0,0 +1 @@ +package main diff --git a/main.go b/main.go index c5fd75e..3bae8fe 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,8 @@ import ( const DefaultInstructModel = "gpt-3.5-turbo-instruct" +const StableCodeModelPrefix = "stable-code" + type config struct { Bind string `json:"bind"` ProxyUrl string `json:"proxy_url"` @@ -168,8 +170,8 @@ func AuthMiddleware(authToken string) gin.HandlerFunc { } } -func (s *ProxyService) InitRoutes(e *gin.Engine, cfg *config) { - authToken := cfg.AuthToken // replace with your dynamic value as needed +func (s *ProxyService) InitRoutes(e *gin.Engine) { + authToken := s.cfg.AuthToken // replace with your dynamic value as needed if authToken != "" { // 鉴权 v1 := e.Group("/:token/v1/", AuthMiddleware(authToken)) @@ -281,29 +283,10 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { return } - body, _ = sjson.DeleteBytes(body, "extra") - body, _ = sjson.DeleteBytes(body, "nwo") - suffix := gjson.GetBytes(body, "suffix") - prompt := gjson.GetBytes(body, "prompt") - content := fmt.Sprintf("%s%s", prompt, suffix) + body = ConstructRequestBody(body, s.cfg) - // 创建新的 JSON 对象并添加到 body 中 - messages := []map[string]string{ - { - "role": "user", - "content": content, - }, - } - body, _ = sjson.SetBytes(body, "messages", messages) - body, _ = sjson.SetBytes(body, "model", s.cfg.CodeInstructModel) - - // fmt.Printf("Request Body: %s\n", body) - // 2. 将转义的字符替换回原来的字符 - jsonStr := string(body) - jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<") - jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">") proxyUrl := s.cfg.CodexApiBase + "/completions" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer([]byte(jsonStr)))) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) if nil != err { // abortCodex(c, http.StatusInternalServerError) @@ -350,6 +333,47 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { _, _ = io.Copy(c.Writer, resp.Body) } +func ConstructRequestBody(body []byte, cfg *config) []byte { + body, _ = sjson.DeleteBytes(body, "extra") + body, _ = sjson.DeleteBytes(body, "nwo") + body, _ = sjson.SetBytes(body, "model", cfg.CodeInstructModel) + if strings.Contains(cfg.CodeInstructModel, StableCodeModelPrefix) { + return constructWithStableCodeModel(body) + } + if strings.HasSuffix(cfg.ChatApiBase, "chat") { + // @Todo constructWithChatModel + // 如果code base以chat结尾则构建chatModel,暂时没有好的prompt + } + return body +} + +func constructWithStableCodeModel(body []byte) []byte { + suffix := gjson.GetBytes(body, "suffix") + prompt := gjson.GetBytes(body, "prompt") + content := fmt.Sprintf("%s%s", prompt, suffix) + + // 创建新的 JSON 对象并添加到 body 中 + messages := []map[string]string{ + { + "role": "user", + "content": content, + }, + } + return constructWithChatModel(body, messages) +} + +func constructWithChatModel(body []byte, messages interface{}) []byte { + + body, _ = sjson.SetBytes(body, "messages", messages) + + // fmt.Printf("Request Body: %s\n", body) + // 2. 将转义的字符替换回原来的字符 + jsonStr := string(body) + jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<") + jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">") + return []byte(jsonStr) +} + func main() { cfg := readConfig() @@ -362,7 +386,7 @@ func main() { return } - proxyService.InitRoutes(r, cfg) + proxyService.InitRoutes(r) err = r.Run(cfg.Bind) if nil != err {