diff --git a/main.go b/main.go index c89d734..6a80870 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "github.com/gin-gonic/gin" - "github.com/linux-do/tiktoken-go" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "golang.org/x/net/http2" @@ -15,14 +14,9 @@ import ( "net/http" "net/url" "os" - "reflect" - "strconv" - "strings" "time" ) -const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct" - type config struct { Bind string `json:"bind"` ProxyUrl string `json:"proxy_url"` @@ -31,7 +25,6 @@ type config struct { CodexApiKey string `json:"codex_api_key"` CodexApiOrganization string `json:"codex_api_organization"` CodexApiProject string `json:"codex_api_project"` - CodexMaxTokens int `json:"codex_max_tokens"` ChatApiBase string `json:"chat_api_base"` ChatApiKey string `json:"chat_api_key"` ChatApiOrganization string `json:"chat_api_organization"` @@ -52,43 +45,6 @@ func readConfig() *config { log.Fatal(err) } - v := reflect.ValueOf(_cfg).Elem() - t := v.Type() - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - tag := t.Field(i).Tag.Get("json") - if tag == "" { - continue - } - - value, exists := os.LookupEnv("OVERRIDE_" + strings.ToUpper(tag)) - if !exists { - continue - } - - switch field.Kind() { - case reflect.String: - field.SetString(value) - case reflect.Bool: - if boolValue, err := strconv.ParseBool(value); err == nil { - field.SetBool(boolValue) - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if intValue, err := strconv.ParseInt(value, 10, 64); err == nil { - field.SetInt(intValue) - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - if uintValue, err := strconv.ParseUint(value, 10, 64); err == nil { - field.SetUint(uintValue) - } - case reflect.Float32, reflect.Float64: - if floatValue, err := strconv.ParseFloat(value, field.Type().Bits()); err == nil { - field.SetFloat(floatValue) - } - } - } - return _cfg } @@ -135,9 +91,8 @@ func closeIO(c io.Closer) { } type ProxyService struct { - cfg *config - client *http.Client - tokenizer *tiktoken.Tiktoken + cfg *config + client *http.Client } func NewProxyService(cfg *config) (*ProxyService, error) { @@ -146,15 +101,9 @@ func NewProxyService(cfg *config) (*ProxyService, error) { return nil, err } - tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL) - if nil != err { - return nil, err - } - return &ProxyService{ - cfg: cfg, - client: client, - tokenizer: tokenizer, + cfg: cfg, + client: client, }, nil } @@ -242,33 +191,9 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { return } - prompt := gjson.GetBytes(body, "prompt").String() - suffix := gjson.GetBytes(body, "suffix").String() - inputTokens := len(s.tokenizer.Encode(prompt, nil, nil)) - suffixTokens := len(s.tokenizer.Encode(suffix, nil, nil)) - outputTokens := int(gjson.GetBytes(body, "max_tokens").Int()) - - totalTokens := inputTokens + suffixTokens + outputTokens - if totalTokens > s.cfg.CodexMaxTokens { // reduce - left, right := 0, len(prompt) - for left < right { - mid := (left + right) / 2 - subPrompt := prompt[mid:] - subInputTokens := len(s.tokenizer.Encode(subPrompt, nil, nil)) - totalTokens = subInputTokens + suffixTokens + outputTokens - if totalTokens > s.cfg.CodexMaxTokens { - left = mid + 1 - } else { - right = mid - } - } - - body, _ = sjson.SetBytes(body, "prompt", prompt[left:]) - } - body, _ = sjson.DeleteBytes(body, "extra") body, _ = sjson.DeleteBytes(body, "nwo") - body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL) + body, _ = sjson.SetBytes(body, "model", "gpt-3.5-turbo-instruct") proxyUrl := s.cfg.CodexApiBase + "/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))