add stable-code-3b local model support

This commit is contained in:
liuzhifei 2024-05-22 15:53:10 +08:00
parent 7992cbe8f2
commit 4749a0945f
1 changed files with 48 additions and 6 deletions

54
main.go
View File

@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@ -20,7 +21,7 @@ import (
"time"
)
const InstructModel = "gpt-3.5-turbo-instruct"
const InstructModel = "stable-code:3b-code-fp16"
type config struct {
Bind string `json:"bind"`
@ -38,6 +39,7 @@ type config struct {
ChatModelDefault string `json:"chat_model_default"`
ChatModelMap map[string]string `json:"chat_model_map"`
ChatLocale string `json:"chat_locale"`
AuthToken string `json:"auth_token"`
}
func readConfig() *config {
@ -150,10 +152,31 @@ func NewProxyService(cfg *config) (*ProxyService, error) {
client: client,
}, nil
}
func AuthMiddleware(authToken string) gin.HandlerFunc {
return func(c *gin.Context) {
token := c.Param("token")
if token != authToken {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
c.Next()
}
}
func (s *ProxyService) InitRoutes(e *gin.Engine) {
e.POST("/v1/chat/completions", s.completions)
e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions)
func (s *ProxyService) InitRoutes(e *gin.Engine, cfg *config) {
authToken := cfg.AuthToken // replace with your dynamic value as needed
if authToken != "" {
// 鉴权
v1 := e.Group("/:token/v1/", AuthMiddleware(authToken))
{
v1.POST("/chat/completions", s.completions)
v1.POST("/engines/copilot-codex/completions", s.codeCompletions)
}
} else {
e.POST("/v1/chat/completions", s.completions)
e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions)
}
}
func (s *ProxyService) completions(c *gin.Context) {
@ -256,11 +279,29 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
body, _ = sjson.DeleteBytes(body, "extra")
body, _ = sjson.DeleteBytes(body, "nwo")
suffix := gjson.GetBytes(body, "suffix")
prompt := gjson.GetBytes(body, "prompt")
content := fmt.Sprintf("<fim_prefix>%s<fim_suffix>%s<fim_middle>", prompt, suffix)
// 创建新的 JSON 对象并添加到 body 中
messages := []map[string]string{
{
"role": "user",
"content": content,
},
}
body, _ = sjson.SetBytes(body, "messages", messages)
body, _ = sjson.SetBytes(body, "model", InstructModel)
// fmt.Printf("Request Body: %s\n", body)
// 2. 将转义的字符替换回原来的字符
jsonStr := string(body)
jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<")
jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">")
proxyUrl := s.cfg.CodexApiBase + "/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer([]byte(jsonStr))))
if nil != err {
//
abortCodex(c, http.StatusInternalServerError)
return
}
@ -317,11 +358,12 @@ func main() {
return
}
proxyService.InitRoutes(r)
proxyService.InitRoutes(r, cfg)
err = r.Run(cfg.Bind)
if nil != err {
log.Fatal(err)
return
}
}