From 7be0c5ecd282ec583ff7cfc91579ce4d46552e2b Mon Sep 17 00:00:00 2001 From: wozulong <> Date: Fri, 17 May 2024 18:36:40 +0800 Subject: [PATCH] fix chat Signed-off-by: wozulong <> --- main.go | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/main.go b/main.go index c89d734..e387ac2 100644 --- a/main.go +++ b/main.go @@ -21,7 +21,7 @@ import ( "time" ) -const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct" +const InstructModel = "gpt-3.5-turbo-instruct" type config struct { Bind string `json:"bind"` @@ -146,7 +146,7 @@ func NewProxyService(cfg *config) (*ProxyService, error) { return nil, err } - tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL) + tokenizer, err := tiktoken.EncodingForModel(InstructModel) if nil != err { return nil, err } @@ -180,6 +180,8 @@ func (s *ProxyService) completions(c *gin.Context) { } body, _ = sjson.SetBytes(body, "model", model) body, _ = sjson.DeleteBytes(body, "intent") + body, _ = sjson.DeleteBytes(body, "intent_threshold") + body, _ = sjson.DeleteBytes(body, "intent_content") proxyUrl := s.cfg.ChatApiBase + "/chat/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) @@ -227,6 +229,14 @@ func (s *ProxyService) completions(c *gin.Context) { _, _ = io.Copy(c.Writer, resp.Body) } +func (s *ProxyService) countToken(token string) int { + if "" == token { + return 0 + } + + return len(s.tokenizer.Encode(token, nil, nil)) +} + func (s *ProxyService) codeCompletions(c *gin.Context) { ctx := c.Request.Context() @@ -244,8 +254,8 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { prompt := gjson.GetBytes(body, "prompt").String() suffix := gjson.GetBytes(body, "suffix").String() - inputTokens := len(s.tokenizer.Encode(prompt, nil, nil)) - suffixTokens := len(s.tokenizer.Encode(suffix, nil, nil)) + inputTokens := s.countToken(prompt) + suffixTokens := s.countToken(suffix) outputTokens := int(gjson.GetBytes(body, "max_tokens").Int()) totalTokens := inputTokens + suffixTokens + outputTokens @@ -254,7 +264,7 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { for left < right { mid := (left + right) / 2 subPrompt := prompt[mid:] - subInputTokens := len(s.tokenizer.Encode(subPrompt, nil, nil)) + subInputTokens := s.countToken(subPrompt) totalTokens = subInputTokens + suffixTokens + outputTokens if totalTokens > s.cfg.CodexMaxTokens { left = mid + 1 @@ -268,7 +278,7 @@ func (s *ProxyService) codeCompletions(c *gin.Context) { body, _ = sjson.DeleteBytes(body, "extra") body, _ = sjson.DeleteBytes(body, "nwo") - body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL) + body, _ = sjson.SetBytes(body, "model", InstructModel) proxyUrl := s.cfg.CodexApiBase + "/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))