From bf4957fc1dae7b23391d9a93240f32a7cf170b07 Mon Sep 17 00:00:00 2001 From: wozulong <> Date: Fri, 17 May 2024 11:51:54 +0800 Subject: [PATCH] reduce prompt tokens Signed-off-by: wozulong <> --- config.json | 1 + go.mod | 3 ++ go.sum | 6 ++++ main.go | 85 +++++++++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 90 insertions(+), 5 deletions(-) diff --git a/config.json b/config.json index a79b721..27a4db5 100644 --- a/config.json +++ b/config.json @@ -6,6 +6,7 @@ "codex_api_key": "sk-xxx", "codex_api_organization": "", "codex_api_project": "", + "codex_max_tokens": 4093, "chat_api_base": "https://api-proxy.oaipro.com/v1", "chat_api_key": "sk-xxx", "chat_api_organization": "", diff --git a/go.mod b/go.mod index 5ef356a..34c8f1d 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.21.4 require ( github.com/gin-gonic/gin v1.10.0 + github.com/linux-do/tiktoken-go v0.7.0 github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 golang.org/x/net v0.25.0 @@ -16,6 +17,7 @@ require ( github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/dlclark/regexp2 v1.11.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -23,6 +25,7 @@ require ( github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.5.9 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/kr/pretty v0.3.0 // indirect diff --git a/go.sum b/go.sum index ebce207..e05fea7 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= +github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -29,6 +31,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -45,6 +49,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/linux-do/tiktoken-go v0.7.0 h1:Kcm/miJ5gp77srtF8GQWnfq7W9kTaXEuHZg/g9IVEu8= +github.com/linux-do/tiktoken-go v0.7.0/go.mod h1:9Vkdtp0ngi4USmrdSx984iuIQ5IMr0hnUdz4jZZTJb8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/main.go b/main.go index 6a80870..c89d734 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "github.com/gin-gonic/gin" + "github.com/linux-do/tiktoken-go" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "golang.org/x/net/http2" @@ -14,9 +15,14 @@ import ( "net/http" "net/url" "os" + "reflect" + "strconv" + "strings" "time" ) +const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct" + type config struct { Bind string `json:"bind"` ProxyUrl string `json:"proxy_url"` @@ -25,6 +31,7 @@ type config struct { CodexApiKey string `json:"codex_api_key"` CodexApiOrganization string `json:"codex_api_organization"` CodexApiProject string `json:"codex_api_project"` + CodexMaxTokens int `json:"codex_max_tokens"` ChatApiBase string `json:"chat_api_base"` ChatApiKey string `json:"chat_api_key"` ChatApiOrganization string `json:"chat_api_organization"` @@ -45,6 +52,43 @@ func readConfig() *config { log.Fatal(err) } + v := reflect.ValueOf(_cfg).Elem() + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + tag := t.Field(i).Tag.Get("json") + if tag == "" { + continue + } + + value, exists := os.LookupEnv("OVERRIDE_" + strings.ToUpper(tag)) + if !exists { + continue + } + + switch field.Kind() { + case reflect.String: + field.SetString(value) + case reflect.Bool: + if boolValue, err := strconv.ParseBool(value); err == nil { + field.SetBool(boolValue) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if intValue, err := strconv.ParseInt(value, 10, 64); err == nil { + field.SetInt(intValue) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if uintValue, err := strconv.ParseUint(value, 10, 64); err == nil { + field.SetUint(uintValue) + } + case reflect.Float32, reflect.Float64: + if floatValue, err := strconv.ParseFloat(value, field.Type().Bits()); err == nil { + field.SetFloat(floatValue) + } + } + } + return _cfg } @@ -91,8 +135,9 @@ func closeIO(c io.Closer) { } type ProxyService struct { - cfg *config - client *http.Client + cfg *config + client *http.Client + tokenizer *tiktoken.Tiktoken } func NewProxyService(cfg *config) (*ProxyService, error) { @@ -101,9 +146,15 @@ func NewProxyService(cfg *config) (*ProxyService, error) { return nil, err } + tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL) + if nil != err { + return nil, err + } + return &ProxyService{ - cfg: cfg, - client: client, + cfg: cfg, + client: client, + tokenizer: tokenizer, }, nil } @@ -191,9 +242,33 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { return } + prompt := gjson.GetBytes(body, "prompt").String() + suffix := gjson.GetBytes(body, "suffix").String() + inputTokens := len(s.tokenizer.Encode(prompt, nil, nil)) + suffixTokens := len(s.tokenizer.Encode(suffix, nil, nil)) + outputTokens := int(gjson.GetBytes(body, "max_tokens").Int()) + + totalTokens := inputTokens + suffixTokens + outputTokens + if totalTokens > s.cfg.CodexMaxTokens { // reduce + left, right := 0, len(prompt) + for left < right { + mid := (left + right) / 2 + subPrompt := prompt[mid:] + subInputTokens := len(s.tokenizer.Encode(subPrompt, nil, nil)) + totalTokens = subInputTokens + suffixTokens + outputTokens + if totalTokens > s.cfg.CodexMaxTokens { + left = mid + 1 + } else { + right = mid + } + } + + body, _ = sjson.SetBytes(body, "prompt", prompt[left:]) + } + body, _ = sjson.DeleteBytes(body, "extra") body, _ = sjson.DeleteBytes(body, "nwo") - body, _ = sjson.SetBytes(body, "model", "gpt-3.5-turbo-instruct") + body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL) proxyUrl := s.cfg.CodexApiBase + "/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))