add stable-code-3b local model support (#30)
* add stable-code-3b local model support * add stable-code-3b local model support * add stable-code-3b local model support * add stable-code-3b local model support * fix code struct add chat model todo
This commit is contained in:
parent
4075c558ec
commit
c9e7d75fec
13
README.md
13
README.md
|
|
@ -34,13 +34,15 @@
|
||||||
"codex_api_key": "sk-xxx",
|
"codex_api_key": "sk-xxx",
|
||||||
"codex_api_organization": "",
|
"codex_api_organization": "",
|
||||||
"codex_api_project": "",
|
"codex_api_project": "",
|
||||||
|
"code_instruct_model": "gpt-3.5-turbo-instruct",
|
||||||
"chat_api_base": "https://api-proxy.oaipro.com/v1",
|
"chat_api_base": "https://api-proxy.oaipro.com/v1",
|
||||||
"chat_api_key": "sk-xxx",
|
"chat_api_key": "sk-xxx",
|
||||||
"chat_api_organization": "",
|
"chat_api_organization": "",
|
||||||
"chat_api_project": "",
|
"chat_api_project": "",
|
||||||
"chat_max_tokens": 4096,
|
"chat_max_tokens": 4096,
|
||||||
"chat_model_default": "gpt-4o",
|
"chat_model_default": "gpt-4o",
|
||||||
"chat_model_map": {}
|
"chat_model_map": {},
|
||||||
|
"auth_token": ""
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -52,6 +54,15 @@
|
||||||
|
|
||||||
可以通过 `OVERRIDE_` + 大写配置项作为环境变量,可以覆盖 `config.json` 中的值。例如:`OVERRIDE_CODEX_API_KEY=sk-xxxx`
|
可以通过 `OVERRIDE_` + 大写配置项作为环境变量,可以覆盖 `config.json` 中的值。例如:`OVERRIDE_CODEX_API_KEY=sk-xxxx`
|
||||||
|
|
||||||
|
### 本地大模型设置
|
||||||
|
1. 安装ollama
|
||||||
|
2. ollama run stable-code:code (这个模型较小,大部分显卡都能跑)
|
||||||
|
如果你的显卡性能较强,可以改用这个模型:ollama run stable-code:3b-code-fp16
|
||||||
|
3. 修改config.json里面的codex_api_base为http://localhost:11434/v1/chat
|
||||||
|
4. 修改code_instruct_model为你的模型名称,stable-code:code或者stable-code:3b-code-fp16
|
||||||
|
5. 剩下的就按照正常流程走即可。
|
||||||
|
6. 如果调不通,请确认http://localhost:11434/v1/chat可用。
|
||||||
|
|
||||||
### 重要说明
|
### 重要说明
|
||||||
`codex_max_tokens` 工作并不完美,已经移除。**JetBrains IDE 完美工作**,`VSCode` 需要执行以下脚本Patch之:
|
`codex_max_tokens` 工作并不完美,已经移除。**JetBrains IDE 完美工作**,`VSCode` 需要执行以下脚本Patch之:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
package main
|
||||||
82
main.go
82
main.go
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
|
|
@ -20,7 +21,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const InstructModel = "gpt-3.5-turbo-instruct"
|
const DefaultInstructModel = "gpt-3.5-turbo-instruct"
|
||||||
|
|
||||||
|
const StableCodeModelPrefix = "stable-code"
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
Bind string `json:"bind"`
|
Bind string `json:"bind"`
|
||||||
|
|
@ -30,6 +33,7 @@ type config struct {
|
||||||
CodexApiKey string `json:"codex_api_key"`
|
CodexApiKey string `json:"codex_api_key"`
|
||||||
CodexApiOrganization string `json:"codex_api_organization"`
|
CodexApiOrganization string `json:"codex_api_organization"`
|
||||||
CodexApiProject string `json:"codex_api_project"`
|
CodexApiProject string `json:"codex_api_project"`
|
||||||
|
CodeInstructModel string `json:"code_instruct_model"`
|
||||||
ChatApiBase string `json:"chat_api_base"`
|
ChatApiBase string `json:"chat_api_base"`
|
||||||
ChatApiKey string `json:"chat_api_key"`
|
ChatApiKey string `json:"chat_api_key"`
|
||||||
ChatApiOrganization string `json:"chat_api_organization"`
|
ChatApiOrganization string `json:"chat_api_organization"`
|
||||||
|
|
@ -38,6 +42,7 @@ type config struct {
|
||||||
ChatModelDefault string `json:"chat_model_default"`
|
ChatModelDefault string `json:"chat_model_default"`
|
||||||
ChatModelMap map[string]string `json:"chat_model_map"`
|
ChatModelMap map[string]string `json:"chat_model_map"`
|
||||||
ChatLocale string `json:"chat_locale"`
|
ChatLocale string `json:"chat_locale"`
|
||||||
|
AuthToken string `json:"auth_token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func readConfig() *config {
|
func readConfig() *config {
|
||||||
|
|
@ -88,6 +93,9 @@ func readConfig() *config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if _cfg.CodeInstructModel == "" {
|
||||||
|
_cfg.CodeInstructModel = DefaultInstructModel
|
||||||
|
}
|
||||||
|
|
||||||
return _cfg
|
return _cfg
|
||||||
}
|
}
|
||||||
|
|
@ -150,10 +158,31 @@ func NewProxyService(cfg *config) (*ProxyService, error) {
|
||||||
client: client,
|
client: client,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
func AuthMiddleware(authToken string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
token := c.Param("token")
|
||||||
|
if token != authToken {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ProxyService) InitRoutes(e *gin.Engine) {
|
func (s *ProxyService) InitRoutes(e *gin.Engine) {
|
||||||
e.POST("/v1/chat/completions", s.completions)
|
authToken := s.cfg.AuthToken // replace with your dynamic value as needed
|
||||||
e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions)
|
if authToken != "" {
|
||||||
|
// 鉴权
|
||||||
|
v1 := e.Group("/:token/v1/", AuthMiddleware(authToken))
|
||||||
|
{
|
||||||
|
v1.POST("/chat/completions", s.completions)
|
||||||
|
v1.POST("/engines/copilot-codex/completions", s.codeCompletions)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
e.POST("/v1/chat/completions", s.completions)
|
||||||
|
e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyService) completions(c *gin.Context) {
|
func (s *ProxyService) completions(c *gin.Context) {
|
||||||
|
|
@ -254,13 +283,12 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ = sjson.DeleteBytes(body, "extra")
|
body = ConstructRequestBody(body, s.cfg)
|
||||||
body, _ = sjson.DeleteBytes(body, "nwo")
|
|
||||||
body, _ = sjson.SetBytes(body, "model", InstructModel)
|
|
||||||
|
|
||||||
proxyUrl := s.cfg.CodexApiBase + "/completions"
|
proxyUrl := s.cfg.CodexApiBase + "/completions"
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
|
||||||
if nil != err {
|
if nil != err {
|
||||||
|
//
|
||||||
abortCodex(c, http.StatusInternalServerError)
|
abortCodex(c, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -305,6 +333,47 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
|
||||||
_, _ = io.Copy(c.Writer, resp.Body)
|
_, _ = io.Copy(c.Writer, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ConstructRequestBody(body []byte, cfg *config) []byte {
|
||||||
|
body, _ = sjson.DeleteBytes(body, "extra")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "nwo")
|
||||||
|
body, _ = sjson.SetBytes(body, "model", cfg.CodeInstructModel)
|
||||||
|
if strings.Contains(cfg.CodeInstructModel, StableCodeModelPrefix) {
|
||||||
|
return constructWithStableCodeModel(body)
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(cfg.ChatApiBase, "chat") {
|
||||||
|
// @Todo constructWithChatModel
|
||||||
|
// 如果code base以chat结尾则构建chatModel,暂时没有好的prompt
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func constructWithStableCodeModel(body []byte) []byte {
|
||||||
|
suffix := gjson.GetBytes(body, "suffix")
|
||||||
|
prompt := gjson.GetBytes(body, "prompt")
|
||||||
|
content := fmt.Sprintf("<fim_prefix>%s<fim_suffix>%s<fim_middle>", prompt, suffix)
|
||||||
|
|
||||||
|
// 创建新的 JSON 对象并添加到 body 中
|
||||||
|
messages := []map[string]string{
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": content,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return constructWithChatModel(body, messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func constructWithChatModel(body []byte, messages interface{}) []byte {
|
||||||
|
|
||||||
|
body, _ = sjson.SetBytes(body, "messages", messages)
|
||||||
|
|
||||||
|
// fmt.Printf("Request Body: %s\n", body)
|
||||||
|
// 2. 将转义的字符替换回原来的字符
|
||||||
|
jsonStr := string(body)
|
||||||
|
jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<")
|
||||||
|
jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">")
|
||||||
|
return []byte(jsonStr)
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
cfg := readConfig()
|
cfg := readConfig()
|
||||||
|
|
||||||
|
|
@ -324,4 +393,5 @@ func main() {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue