fix code struct add chat model todo

This commit is contained in:
liuzhifei 2024-05-23 14:25:08 +08:00
parent b52728f665
commit 3e37b44c40
2 changed files with 49 additions and 24 deletions

1
localModel.go Normal file
View File

@ -0,0 +1 @@
package main

72
main.go
View File

@ -23,6 +23,8 @@ import (
const DefaultInstructModel = "gpt-3.5-turbo-instruct"
const StableCodeModelPrefix = "stable-code"
type config struct {
Bind string `json:"bind"`
ProxyUrl string `json:"proxy_url"`
@ -168,8 +170,8 @@ func AuthMiddleware(authToken string) gin.HandlerFunc {
}
}
func (s *ProxyService) InitRoutes(e *gin.Engine, cfg *config) {
authToken := cfg.AuthToken // replace with your dynamic value as needed
func (s *ProxyService) InitRoutes(e *gin.Engine) {
authToken := s.cfg.AuthToken // replace with your dynamic value as needed
if authToken != "" {
// 鉴权
v1 := e.Group("/:token/v1/", AuthMiddleware(authToken))
@ -281,29 +283,10 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
return
}
body, _ = sjson.DeleteBytes(body, "extra")
body, _ = sjson.DeleteBytes(body, "nwo")
suffix := gjson.GetBytes(body, "suffix")
prompt := gjson.GetBytes(body, "prompt")
content := fmt.Sprintf("<fim_prefix>%s<fim_suffix>%s<fim_middle>", prompt, suffix)
body = ConstructRequestBody(body, s.cfg)
// 创建新的 JSON 对象并添加到 body 中
messages := []map[string]string{
{
"role": "user",
"content": content,
},
}
body, _ = sjson.SetBytes(body, "messages", messages)
body, _ = sjson.SetBytes(body, "model", s.cfg.CodeInstructModel)
// fmt.Printf("Request Body: %s\n", body)
// 2. 将转义的字符替换回原来的字符
jsonStr := string(body)
jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<")
jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">")
proxyUrl := s.cfg.CodexApiBase + "/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer([]byte(jsonStr))))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
if nil != err {
//
abortCodex(c, http.StatusInternalServerError)
@ -350,6 +333,47 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
_, _ = io.Copy(c.Writer, resp.Body)
}
func ConstructRequestBody(body []byte, cfg *config) []byte {
body, _ = sjson.DeleteBytes(body, "extra")
body, _ = sjson.DeleteBytes(body, "nwo")
body, _ = sjson.SetBytes(body, "model", cfg.CodeInstructModel)
if strings.Contains(cfg.CodeInstructModel, StableCodeModelPrefix) {
return constructWithStableCodeModel(body)
}
if strings.HasSuffix(cfg.ChatApiBase, "chat") {
// @Todo constructWithChatModel
// 如果code base以chat结尾则构建chatModel暂时没有好的prompt
}
return body
}
func constructWithStableCodeModel(body []byte) []byte {
suffix := gjson.GetBytes(body, "suffix")
prompt := gjson.GetBytes(body, "prompt")
content := fmt.Sprintf("<fim_prefix>%s<fim_suffix>%s<fim_middle>", prompt, suffix)
// 创建新的 JSON 对象并添加到 body 中
messages := []map[string]string{
{
"role": "user",
"content": content,
},
}
return constructWithChatModel(body, messages)
}
func constructWithChatModel(body []byte, messages interface{}) []byte {
body, _ = sjson.SetBytes(body, "messages", messages)
// fmt.Printf("Request Body: %s\n", body)
// 2. 将转义的字符替换回原来的字符
jsonStr := string(body)
jsonStr = strings.ReplaceAll(jsonStr, "\\u003c", "<")
jsonStr = strings.ReplaceAll(jsonStr, "\\u003e", ">")
return []byte(jsonStr)
}
func main() {
cfg := readConfig()
@ -362,7 +386,7 @@ func main() {
return
}
proxyService.InitRoutes(r, cfg)
proxyService.InitRoutes(r)
err = r.Run(cfg.Bind)
if nil != err {