From c9e7d75fec4da5529301e083f4ccdc327f4cfe73 Mon Sep 17 00:00:00 2001 From: xixingya <2679431923@qq.com> Date: Thu, 23 May 2024 18:54:40 +0800 Subject: [PATCH] add stable-code-3b local model support (#30) * add stable-code-3b local model support * add stable-code-3b local model support * add stable-code-3b local model support * add stable-code-3b local model support * fix code struct add chat model todo --- README.md | 13 +++++++- localModel.go | 1 + main.go | 82 +++++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 89 insertions(+), 7 deletions(-) create mode 100644 localModel.go diff --git a/README.md b/README.md index ee7f073..b5e42d7 100644 --- a/README.md +++ b/README.md @@ -34,13 +34,15 @@ "codex_api_key": "sk-xxx", "codex_api_organization": "", "codex_api_project": "", + "code_instruct_model": "gpt-3.5-turbo-instruct", "chat_api_base": "https://api-proxy.oaipro.com/v1", "chat_api_key": "sk-xxx", "chat_api_organization": "", "chat_api_project": "", "chat_max_tokens": 4096, "chat_model_default": "gpt-4o", - "chat_model_map": {} + "chat_model_map": {}, + "auth_token": "" } ``` @@ -52,6 +54,15 @@ 可以通过 `OVERRIDE_` + 大写配置项作为环境变量,可以覆盖 `config.json` 中的值。例如:`OVERRIDE_CODEX_API_KEY=sk-xxxx` +### 本地大模型设置 +1. 安装ollama +2. ollama run stable-code:code (这个模型较小,大部分显卡都能跑) + 或者你的显卡比较高安装这个:ollama run stable-code:3b-code-fp16 +3. 修改config.json里面的codex_api_base为http://localhost:11434/v1/chat +4. 修改code_instruct_model为你的模型名称,stable-code:code或者stable-code:3b-code-fp16 +4. 剩下的就按照正常流程走即可。 +5. 如果调不通,请确认http://localhost:11434/v1/chat可用。 + ### 重要说明 `codex_max_tokens` 工作并不完美,已经移除。**JetBrains IDE 完美工作**,`VSCode` 需要执行以下脚本Patch之: diff --git a/localModel.go b/localModel.go new file mode 100644 index 0000000..06ab7d0 --- /dev/null +++ b/localModel.go @@ -0,0 +1 @@ +package main diff --git a/main.go b/main.go index 1a36a42..3bae8fe 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -20,7 +21,9 @@ import ( "time" ) -const InstructModel = "gpt-3.5-turbo-instruct" +const DefaultInstructModel = "gpt-3.5-turbo-instruct" + +const StableCodeModelPrefix = "stable-code" type config struct { Bind string `json:"bind"` @@ -30,6 +33,7 @@ type config struct { CodexApiKey string `json:"codex_api_key"` CodexApiOrganization string `json:"codex_api_organization"` CodexApiProject string `json:"codex_api_project"` + CodeInstructModel string `json:"code_instruct_model"` ChatApiBase string `json:"chat_api_base"` ChatApiKey string `json:"chat_api_key"` ChatApiOrganization string `json:"chat_api_organization"` @@ -38,6 +42,7 @@ type config struct { ChatModelDefault string `json:"chat_model_default"` ChatModelMap map[string]string `json:"chat_model_map"` ChatLocale string `json:"chat_locale"` + AuthToken string `json:"auth_token"` } func readConfig() *config { @@ -88,6 +93,9 @@ func readConfig() *config { } } } + if _cfg.CodeInstructModel == "" { + _cfg.CodeInstructModel = DefaultInstructModel + } return _cfg } @@ -150,10 +158,31 @@ func NewProxyService(cfg *config) (*ProxyService, error) { client: client, }, nil } +func AuthMiddleware(authToken string) gin.HandlerFunc { + return func(c *gin.Context) { + token := c.Param("token") + if token != authToken { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + c.Next() + } +} func (s *ProxyService) InitRoutes(e *gin.Engine) { - e.POST("/v1/chat/completions", s.completions) - e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions) + authToken := s.cfg.AuthToken // replace with your dynamic value as needed + if authToken != "" { + // 鉴权 + v1 := e.Group("/:token/v1/", AuthMiddleware(authToken)) + { + v1.POST("/chat/completions", s.completions) + v1.POST("/engines/copilot-codex/completions", s.codeCompletions) + } + } else { + e.POST("/v1/chat/completions", s.completions) + e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions) + } } func (s *ProxyService) completions(c *gin.Context) { @@ -254,13 +283,12 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { return } - body, _ = sjson.DeleteBytes(body, "extra") - body, _ = sjson.DeleteBytes(body, "nwo") - body, _ = sjson.SetBytes(body, "model", InstructModel) + body = ConstructRequestBody(body, s.cfg) proxyUrl := s.cfg.CodexApiBase + "/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) if nil != err { + // abortCodex(c, http.StatusInternalServerError) return } @@ -305,6 +333,47 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { _, _ = io.Copy(c.Writer, resp.Body) } +func ConstructRequestBody(body []byte, cfg *config) []byte { + body, _ = sjson.DeleteBytes(body, "extra") + body, _ = sjson.DeleteBytes(body, "nwo") + body, _ = sjson.SetBytes(body, "model", cfg.CodeInstructModel) + if strings.Contains(cfg.CodeInstructModel, StableCodeModelPrefix) { + return constructWithStableCodeModel(body) + } + if strings.HasSuffix(cfg.ChatApiBase, "chat") { + // @Todo constructWithChatModel + // 如果code base以chat结尾则构建chatModel,暂时没有好的prompt + } + return body +} + +func constructWithStableCodeModel(body []byte) []byte { + suffix := gjson.GetBytes(body, "suffix") + prompt := gjson.GetBytes(body, "prompt") + content := fmt.Sprintf("%s%s", prompt, suffix) + + // 创建新的 JSON 对象并添加到 body 中 + messages := []map[string]string{ + { + "role": "user", + "content": content, + }, + } + return constructWithChatModel(body, messages) +} + +func constructWithChatModel(body []byte, messages interface{}) []byte { + + body, _ = sjson.SetBytes(body, "messages", messages) + + // fmt.Printf("Request Body: %s\n", body) + // 2. 将转义的字符替换回原来的字符 + jsonStr := string(body) + jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<") + jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">") + return []byte(jsonStr) +} + func main() { cfg := readConfig() @@ -324,4 +393,5 @@ func main() { log.Fatal(err) return } + }