diff --git a/README.md b/README.md index 3a52857..d8f36f4 100644 --- a/README.md +++ b/README.md @@ -34,13 +34,15 @@ "codex_api_key": "sk-xxx", "codex_api_organization": "", "codex_api_project": "", - "codex_max_tokens": 4093, + "code_instruct_model": "gpt-3.5-turbo-instruct", "chat_api_base": "https://api-proxy.oaipro.com/v1", "chat_api_key": "sk-xxx", "chat_api_organization": "", "chat_api_project": "", + "chat_max_tokens": 4096, "chat_model_default": "gpt-4o", - "chat_model_map": {} + "chat_model_map": {}, + "auth_token": "" } ``` @@ -48,10 +50,59 @@ `chat_model_map` 是个模型映射的字典。会将请求的模型映射到你想要的,如果不存在映射,则使用 `chat_model_default` 。 -`code_max_tokens` 可以设置为你希望的最大Token数,你设置的时候最好知道自己在做什么。 +`chat_max_tokens` 可以设置为你希望的最大Token数,你设置的时候最好知道自己在做什么。`gpt-4o` 输出最大为 `4096` 可以通过 `OVERRIDE_` + 大写配置项作为环境变量,可以覆盖 `config.json` 中的值。例如:`OVERRIDE_CODEX_API_KEY=sk-xxxx` +### 本地大模型设置 +1. 安装ollama +2. ollama run stable-code:code (这个模型较小,大部分显卡都能跑) + 或者你的显卡比较高安装这个:ollama run stable-code:3b-code-fp16 +3. 修改config.json里面的codex_api_base为http://localhost:11434/v1/chat +4. 修改code_instruct_model为你的模型名称,stable-code:code或者stable-code:3b-code-fp16 +4. 剩下的就按照正常流程走即可。 +5. 如果调不通,请确认http://localhost:11434/v1/chat可用。 + +### 重要说明 +`codex_max_tokens` 工作并不完美,已经移除。**JetBrains IDE 完美工作**,`VSCode` 需要执行以下脚本Patch之: + +* macOS `sed -i '' -E 's/\.maxPromptCompletionTokens\(([a-zA-Z0-9_]+),([0-9]+)\)/.maxPromptCompletionTokens(\1,2048)/' ~/.vscode/extensions/github.copilot-*/dist/extension.js` +* Linux `sed -E 's/\.maxPromptCompletionTokens\(([a-zA-Z0-9_]+),([0-9]+)\)/.maxPromptCompletionTokens(\1,2048)/' ~/.vscode/extensions/github.copilot-*/dist/extension.js` +* Windows 可以用如下的python脚本进行替换 +* 因为是Patch,所以:**Copilot每次升级都要执行一次**。 +* 具体原因是客户端需要根据 `max_tokens` 精密计算prompt,后台删减会有问题。 + +``` +# github copilot extention replace script +import re +import glob +import os + +file_paths = glob.glob(os.getenv("USERPROFILE") + r'\.vscode\extensions\github.copilot-*\dist\extension.js') +if file_paths == list(): + print("no copilot extension found") + exit() + +pattern = re.compile(r'\.maxPromptCompletionTokens\(([a-zA-Z0-9_]+),([0-9]+)\)') +replacement = r'.maxPromptCompletionTokens(\1,2048)' + +for file_path in file_paths: + with open(file_path, 'r', encoding="utf-8") as file: + content = file.read() + + new_content = pattern.sub(replacement, content) + if new_content == content: + print("no match found in " + file_path) + continue + else: + print("replaced " + file_path) + + with open(file_path, 'w', encoding='utf-8') as file: + file.write(new_content) + +print("replace finish") +``` + ### 其他说明 1. 理论上,Chat 部分可以使用 `chat2api` ,而 Codex 代码生成部分则不太适合使用 `chat2api` 。 2. 代码生成部分做过延时生成和客户端 Cancel 处理,很有效节省你的Token。 diff --git a/config.json.example b/config.json.example index 0d315b1..26c77b0 100644 --- a/config.json.example +++ b/config.json.example @@ -6,11 +6,12 @@ "codex_api_key": "sk-xxx", "codex_api_organization": "", "codex_api_project": "", - "codex_max_tokens": 2048, "chat_api_base": "https://api-proxy.oaipro.com/v1", "chat_api_key": "sk-xxx", "chat_api_organization": "", "chat_api_project": "", + "chat_max_tokens": 4096, "chat_model_default": "gpt-4o", - "chat_model_map": {} + "chat_model_map": {}, + "chat_locale": "zh_CN" } \ No newline at end of file diff --git a/go.mod b/go.mod index 34c8f1d..5ef356a 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ toolchain go1.21.4 require ( github.com/gin-gonic/gin v1.10.0 - github.com/linux-do/tiktoken-go v0.7.0 github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 golang.org/x/net v0.25.0 @@ -17,7 +16,6 @@ require ( github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect - github.com/dlclark/regexp2 v1.11.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -25,7 +23,6 @@ require ( github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.5.9 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/kr/pretty v0.3.0 // indirect diff --git a/go.sum b/go.sum index e05fea7..ebce207 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= -github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -31,8 +29,6 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -49,8 +45,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/linux-do/tiktoken-go v0.7.0 h1:Kcm/miJ5gp77srtF8GQWnfq7W9kTaXEuHZg/g9IVEu8= -github.com/linux-do/tiktoken-go v0.7.0/go.mod h1:9Vkdtp0ngi4USmrdSx984iuIQ5IMr0hnUdz4jZZTJb8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/localModel.go b/localModel.go new file mode 100644 index 0000000..06ab7d0 --- /dev/null +++ b/localModel.go @@ -0,0 +1 @@ +package main diff --git a/main.go b/main.go index 8104c02..bebff28 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/linux-do/tiktoken-go" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "golang.org/x/net/http2" @@ -23,7 +22,9 @@ import ( "time" ) -const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct" +const DefaultInstructModel = "gpt-3.5-turbo-instruct" + +const StableCodeModelPrefix = "stable-code" type GPTMessage struct { Role string `json:"role"` @@ -162,8 +163,11 @@ type config struct { ChatApiKey string `json:"chat_api_key"` ChatApiOrganization string `json:"chat_api_organization"` ChatApiProject string `json:"chat_api_project"` + ChatMaxTokens int `json:"chat_max_tokens"` ChatModelDefault string `json:"chat_model_default"` ChatModelMap map[string]string `json:"chat_model_map"` + ChatLocale string `json:"chat_locale"` + AuthToken string `json:"auth_token"` } func readConfig() *config { @@ -214,6 +218,9 @@ func readConfig() *config { } } } + if _cfg.CodeInstructModel == "" { + _cfg.CodeInstructModel = DefaultInstructModel + } return _cfg } @@ -261,9 +268,8 @@ func closeIO(c io.Closer) { } type ProxyService struct { - cfg *config - client *http.Client - tokenizer *tiktoken.Tiktoken + cfg *config + client *http.Client } func NewProxyService(cfg *config) (*ProxyService, error) { @@ -272,21 +278,36 @@ func NewProxyService(cfg *config) (*ProxyService, error) { return nil, err } - tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL) - if nil != err { - return nil, err - } - return &ProxyService{ - cfg: cfg, - client: client, - tokenizer: tokenizer, + cfg: cfg, + client: client, }, nil } +func AuthMiddleware(authToken string) gin.HandlerFunc { + return func(c *gin.Context) { + token := c.Param("token") + if token != authToken { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + c.Next() + } +} func (s *ProxyService) InitRoutes(e *gin.Engine) { - e.POST("/v1/chat/completions", s.completions) - e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions) + authToken := s.cfg.AuthToken // replace with your dynamic value as needed + if authToken != "" { + // 鉴权 + v1 := e.Group("/:token/v1/", AuthMiddleware(authToken)) + { + v1.POST("/chat/completions", s.completions) + v1.POST("/engines/copilot-codex/completions", s.codeCompletions) + } + } else { + e.POST("/v1/chat/completions", s.completions) + e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions) + } } func (s *ProxyService) completions(c *gin.Context) { @@ -305,7 +326,26 @@ func (s *ProxyService) completions(c *gin.Context) { model = s.cfg.ChatModelDefault } body, _ = sjson.SetBytes(body, "model", model) + + if !gjson.GetBytes(body, "function_call").Exists() { + messages := gjson.GetBytes(body, "messages").Array() + lastIndex := len(messages) - 1 + if !strings.Contains(messages[lastIndex].Get("content").String(), "Respond in the following locale") { + locale := s.cfg.ChatLocale + if locale == "" { + locale = "zh_CN" + } + body, _ = sjson.SetBytes(body, "messages."+strconv.Itoa(lastIndex)+".content", messages[lastIndex].Get("content").String()+"Respond in the following locale: "+locale+".") + } + } + body, _ = sjson.DeleteBytes(body, "intent") + body, _ = sjson.DeleteBytes(body, "intent_threshold") + body, _ = sjson.DeleteBytes(body, "intent_content") + + if int(gjson.GetBytes(body, "max_tokens").Int()) > s.cfg.ChatMaxTokens { + body, _ = sjson.SetBytes(body, "max_tokens", s.cfg.ChatMaxTokens) + } proxyUrl := s.cfg.ChatApiBase + "/chat/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) @@ -367,7 +407,6 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { abortCodex(c, http.StatusBadRequest) return } - prompt := gjson.GetBytes(body, "prompt").String() suffix := gjson.GetBytes(body, "suffix").String() inputTokens := len(s.tokenizer.Encode(prompt, nil, nil)) @@ -423,6 +462,7 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { } req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) if nil != err { + // abortCodex(c, http.StatusInternalServerError) return } @@ -534,6 +574,47 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { _ = resp.Body.Close() } +func ConstructRequestBody(body []byte, cfg *config) []byte { + body, _ = sjson.DeleteBytes(body, "extra") + body, _ = sjson.DeleteBytes(body, "nwo") + body, _ = sjson.SetBytes(body, "model", cfg.CodeInstructModel) + if strings.Contains(cfg.CodeInstructModel, StableCodeModelPrefix) { + return constructWithStableCodeModel(body) + } + if strings.HasSuffix(cfg.ChatApiBase, "chat") { + // @Todo constructWithChatModel + // 如果code base以chat结尾则构建chatModel,暂时没有好的prompt + } + return body +} + +func constructWithStableCodeModel(body []byte) []byte { + suffix := gjson.GetBytes(body, "suffix") + prompt := gjson.GetBytes(body, "prompt") + content := fmt.Sprintf("%s%s", prompt, suffix) + + // 创建新的 JSON 对象并添加到 body 中 + messages := []map[string]string{ + { + "role": "user", + "content": content, + }, + } + return constructWithChatModel(body, messages) +} + +func constructWithChatModel(body []byte, messages interface{}) []byte { + + body, _ = sjson.SetBytes(body, "messages", messages) + + // fmt.Printf("Request Body: %s\n", body) + // 2. 将转义的字符替换回原来的字符 + jsonStr := string(body) + jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<") + jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">") + return []byte(jsonStr) +} + func main() { cfg := readConfig() @@ -553,4 +634,5 @@ func main() { log.Fatal(err) return } + } diff --git a/scripts/replace_max_tokens.vbs b/scripts/replace_max_tokens.vbs new file mode 100644 index 0000000..06e61ba --- /dev/null +++ b/scripts/replace_max_tokens.vbs @@ -0,0 +1,49 @@ +' VBScript to change max tokens to 2048 + +MsgBox "It may take a few seconds to execute this script." & vbCrLf & vbCrLf & "Click 'OK' button and wait for the prompt of 'Done.' to pop up!" + +Const ForReading = 1 +Const ForWriting = 2 + +' Subpath of the file to be replaced +subpath = "dist\extension.js" + +pattern = "\.maxPromptCompletionTokens\(([a-zA-Z0-9_]+),([0-9]+)\)" +replacement = ".maxPromptCompletionTokens($1,2048)" + +' Iterate over all github copilot directories +Set objFSO = CreateObject("Scripting.FileSystemObject") +Set objShell = CreateObject("WScript.Shell") +Set colExtensions = objFSO.GetFolder(objShell.ExpandEnvironmentStrings("%USERPROFILE%") & "\.vscode\extensions").SubFolders + +For Each objExtension In colExtensions + extension_path = objExtension.Path & "\" & subpath + If objFSO.FileExists(extension_path) Then + backupfile = extension_path & ".bak" + + ' Delete if backup file exists + If objFSO.FileExists(backupfile) Then + objFSO.DeleteFile backupfile, True + End If + + ' Backup + objFSO.CopyFile extension_path, backupfile + + ' Do search and replace with pattern + Set objFile = objFSO.OpenTextFile(extension_path, ForReading) + strContent = objFile.ReadAll + objFile.Close + + Set objRegEx = New RegExp + objRegEx.Global = True + objRegEx.IgnoreCase = True + objRegEx.Pattern = pattern + strContent = objRegEx.Replace(strContent, replacement) + + Set objFile = objFSO.OpenTextFile(extension_path, ForWriting) + objFile.Write strContent + objFile.Close + End If +Next + +MsgBox "Max tokens modification completed" diff --git a/scripts/restore_max_tokens.vbs b/scripts/restore_max_tokens.vbs new file mode 100644 index 0000000..3a10b8e --- /dev/null +++ b/scripts/restore_max_tokens.vbs @@ -0,0 +1,30 @@ +' VBScript to recovery max tokens +MsgBox "It may take a few seconds to execute this script." & vbCrLf & vbCrLf & "Click 'OK' button and wait for the prompt of 'Done.' to pop up!" + +Const ForReading = 1 +Const ForWriting = 2 + +' Subpath of the file to be recovery +subpath = "dist\extension.js" + +' Iterate over all github copilot directories +Set objFSO = CreateObject("Scripting.FileSystemObject") +Set objShell = CreateObject("WScript.Shell") +Set colExtensions = objFSO.GetFolder(objShell.ExpandEnvironmentStrings("%USERPROFILE%") & "\.vscode\extensions").SubFolders + +For Each objExtension In colExtensions + extension_path = objExtension.Path & "\" & subpath + backupfile = extension_path & ".bak" + + If objFSO.FileExists(backupfile) Then + ' Delete if exist extension file + If objFSO.FileExists(extension_path) Then + objFSO.DeleteFile extension_path, True + End If + + ' Replace + objFSO.MoveFile backupfile, extension_path + End If +Next + +MsgBox "Restore max tokens to default successed"