From 811aad5bfc26e326a100106b7cb12a7a5d06dba7 Mon Sep 17 00:00:00 2001 From: tkisme Date: Fri, 17 May 2024 15:19:48 +0800 Subject: [PATCH] fallback model --- main.go | 126 +++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 80 insertions(+), 46 deletions(-) diff --git a/main.go b/main.go index ee90d77..c89d734 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "github.com/gin-gonic/gin" + "github.com/linux-do/tiktoken-go" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "golang.org/x/net/http2" @@ -14,10 +15,14 @@ import ( "net/http" "net/url" "os" + "reflect" + "strconv" "strings" "time" ) +const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct" + type config struct { Bind string `json:"bind"` ProxyUrl string `json:"proxy_url"` @@ -26,7 +31,7 @@ type config struct { CodexApiKey string `json:"codex_api_key"` CodexApiOrganization string `json:"codex_api_organization"` CodexApiProject string `json:"codex_api_project"` - CodexModelDefault string `json:"codex_model_default"` + CodexMaxTokens int `json:"codex_max_tokens"` ChatApiBase string `json:"chat_api_base"` ChatApiKey string `json:"chat_api_key"` ChatApiOrganization string `json:"chat_api_organization"` @@ -47,35 +52,46 @@ func readConfig() *config { log.Fatal(err) } + v := reflect.ValueOf(_cfg).Elem() + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + tag := t.Field(i).Tag.Get("json") + if tag == "" { + continue + } + + value, exists := os.LookupEnv("OVERRIDE_" + strings.ToUpper(tag)) + if !exists { + continue + } + + switch field.Kind() { + case reflect.String: + field.SetString(value) + case reflect.Bool: + if boolValue, err := strconv.ParseBool(value); err == nil { + field.SetBool(boolValue) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if intValue, err := strconv.ParseInt(value, 10, 64); err == nil { + field.SetInt(intValue) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if uintValue, err := strconv.ParseUint(value, 10, 64); err == nil { + field.SetUint(uintValue) + } + case reflect.Float32, reflect.Float64: + if floatValue, err := strconv.ParseFloat(value, field.Type().Bits()); err == nil { + field.SetFloat(floatValue) + } + } + } + return _cfg } -type GPTMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} -type StreamResponse struct { - Response string `json:"response"` -} -type Message struct { - Role string `json:"role,omitempty"` - Content any `json:"content,omitempty"` - Name *string `json:"name,omitempty"` -} -type ChatCompletionsStreamResponseChoice struct { - Index int `json:"index"` - Delta Message `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` -} - -type ChatCompletionsStreamResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionsStreamResponseChoice `json:"choices"` -} - func getClient(cfg *config) (*http.Client, error) { transport := &http.Transport{ ForceAttemptHTTP2: true, @@ -119,8 +135,9 @@ func closeIO(c io.Closer) { } type ProxyService struct { - cfg *config - client *http.Client + cfg *config + client *http.Client + tokenizer *tiktoken.Tiktoken } func NewProxyService(cfg *config) (*ProxyService, error) { @@ -129,9 +146,15 @@ func NewProxyService(cfg *config) (*ProxyService, error) { return nil, err } + tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL) + if nil != err { + return nil, err + } + return &ProxyService{ - cfg: cfg, - client: client, + cfg: cfg, + client: client, + tokenizer: tokenizer, }, nil } @@ -219,24 +242,35 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { return } + prompt := gjson.GetBytes(body, "prompt").String() + suffix := gjson.GetBytes(body, "suffix").String() + inputTokens := len(s.tokenizer.Encode(prompt, nil, nil)) + suffixTokens := len(s.tokenizer.Encode(suffix, nil, nil)) + outputTokens := int(gjson.GetBytes(body, "max_tokens").Int()) + + totalTokens := inputTokens + suffixTokens + outputTokens + if totalTokens > s.cfg.CodexMaxTokens { // reduce + left, right := 0, len(prompt) + for left < right { + mid := (left + right) / 2 + subPrompt := prompt[mid:] + subInputTokens := len(s.tokenizer.Encode(subPrompt, nil, nil)) + totalTokens = subInputTokens + suffixTokens + outputTokens + if totalTokens > s.cfg.CodexMaxTokens { + left = mid + 1 + } else { + right = mid + } + } + + body, _ = sjson.SetBytes(body, "prompt", prompt[left:]) + } + body, _ = sjson.DeleteBytes(body, "extra") body, _ = sjson.DeleteBytes(body, "nwo") - if s.cfg.CodexModelDefault == "" { - s.cfg.CodexModelDefault = "gpt-3.5-turbo-instruct" - } - body, _ = sjson.SetBytes(body, "model", s.cfg.CodexModelDefault) + body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL) - proxyUrl := s.cfg.CodexApiBase - if strings.HasPrefix(s.cfg.CodexModelDefault, "@") { - proxyUrl = s.cfg.CodexApiBase - message := gjson.GetBytes(body, "prompt").String() - body, _ = sjson.DeleteBytes(body, "prompt") - msg := make([]GPTMessage, 0) - msg = append(msg, GPTMessage{Role: "system", Content: "You are a helpful assistant"}) - msg = append(msg, GPTMessage{Role: "user", Content: message}) - body, _ = sjson.SetBytes(body, "messages", msg) - body, _ = sjson.DeleteBytes(body, "n") - } + proxyUrl := s.cfg.CodexApiBase + "/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) if nil != err { abortCodex(c, http.StatusInternalServerError)