From 4749a0945fabc7d825c87d9a2782aaca194e2b41 Mon Sep 17 00:00:00 2001 From: liuzhifei <2679431923@qq.com> Date: Wed, 22 May 2024 15:53:10 +0800 Subject: [PATCH] add stable-code-3b local model support --- main.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/main.go b/main.go index 1a36a42..17a335d 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -20,7 +21,7 @@ import ( "time" ) -const InstructModel = "gpt-3.5-turbo-instruct" +const InstructModel = "stable-code:3b-code-fp16" type config struct { Bind string `json:"bind"` @@ -38,6 +39,7 @@ type config struct { ChatModelDefault string `json:"chat_model_default"` ChatModelMap map[string]string `json:"chat_model_map"` ChatLocale string `json:"chat_locale"` + AuthToken string `json:"auth_token"` } func readConfig() *config { @@ -150,10 +152,31 @@ func NewProxyService(cfg *config) (*ProxyService, error) { client: client, }, nil } +func AuthMiddleware(authToken string) gin.HandlerFunc { + return func(c *gin.Context) { + token := c.Param("token") + if token != authToken { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + c.Next() + } +} -func (s *ProxyService) InitRoutes(e *gin.Engine) { - e.POST("/v1/chat/completions", s.completions) - e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions) +func (s *ProxyService) InitRoutes(e *gin.Engine, cfg *config) { + authToken := cfg.AuthToken // replace with your dynamic value as needed + if authToken != "" { + // 鉴权 + v1 := e.Group("/:token/v1/", AuthMiddleware(authToken)) + { + v1.POST("/chat/completions", s.completions) + v1.POST("/engines/copilot-codex/completions", s.codeCompletions) + } + } else { + e.POST("/v1/chat/completions", s.completions) + e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions) + } } func (s *ProxyService) completions(c *gin.Context) { @@ -256,11 +279,29 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { body, _ = sjson.DeleteBytes(body, "extra") body, _ = sjson.DeleteBytes(body, "nwo") + suffix := gjson.GetBytes(body, "suffix") + prompt := gjson.GetBytes(body, "prompt") + content := fmt.Sprintf("%s%s", prompt, suffix) + + // 创建新的 JSON 对象并添加到 body 中 + messages := []map[string]string{ + { + "role": "user", + "content": content, + }, + } + body, _ = sjson.SetBytes(body, "messages", messages) body, _ = sjson.SetBytes(body, "model", InstructModel) + // fmt.Printf("Request Body: %s\n", body) + // 2. 将转义的字符替换回原来的字符 + jsonStr := string(body) + jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<") + jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">") proxyUrl := s.cfg.CodexApiBase + "/completions" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer([]byte(jsonStr)))) if nil != err { + // abortCodex(c, http.StatusInternalServerError) return } @@ -317,11 +358,12 @@ func main() { return } - proxyService.InitRoutes(r) + proxyService.InitRoutes(r, cfg) err = r.Run(cfg.Bind) if nil != err { log.Fatal(err) return } + }