Signed-off-by: wozulong <>
This commit is contained in:
wozulong 2024-05-17 18:36:40 +08:00
parent 38faf0bb1b
commit 7be0c5ecd2
1 changed files with 16 additions and 6 deletions

22
main.go
View File

@ -21,7 +21,7 @@ import (
"time" "time"
) )
const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct" const InstructModel = "gpt-3.5-turbo-instruct"
type config struct { type config struct {
Bind string `json:"bind"` Bind string `json:"bind"`
@ -146,7 +146,7 @@ func NewProxyService(cfg *config) (*ProxyService, error) {
return nil, err return nil, err
} }
tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL) tokenizer, err := tiktoken.EncodingForModel(InstructModel)
if nil != err { if nil != err {
return nil, err return nil, err
} }
@ -180,6 +180,8 @@ func (s *ProxyService) completions(c *gin.Context) {
} }
body, _ = sjson.SetBytes(body, "model", model) body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.DeleteBytes(body, "intent") body, _ = sjson.DeleteBytes(body, "intent")
body, _ = sjson.DeleteBytes(body, "intent_threshold")
body, _ = sjson.DeleteBytes(body, "intent_content")
proxyUrl := s.cfg.ChatApiBase + "/chat/completions" proxyUrl := s.cfg.ChatApiBase + "/chat/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
@ -227,6 +229,14 @@ func (s *ProxyService) completions(c *gin.Context) {
_, _ = io.Copy(c.Writer, resp.Body) _, _ = io.Copy(c.Writer, resp.Body)
} }
func (s *ProxyService) countToken(token string) int {
if "" == token {
return 0
}
return len(s.tokenizer.Encode(token, nil, nil))
}
func (s *ProxyService) codeCompletions(c *gin.Context) { func (s *ProxyService) codeCompletions(c *gin.Context) {
ctx := c.Request.Context() ctx := c.Request.Context()
@ -244,8 +254,8 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
prompt := gjson.GetBytes(body, "prompt").String() prompt := gjson.GetBytes(body, "prompt").String()
suffix := gjson.GetBytes(body, "suffix").String() suffix := gjson.GetBytes(body, "suffix").String()
inputTokens := len(s.tokenizer.Encode(prompt, nil, nil)) inputTokens := s.countToken(prompt)
suffixTokens := len(s.tokenizer.Encode(suffix, nil, nil)) suffixTokens := s.countToken(suffix)
outputTokens := int(gjson.GetBytes(body, "max_tokens").Int()) outputTokens := int(gjson.GetBytes(body, "max_tokens").Int())
totalTokens := inputTokens + suffixTokens + outputTokens totalTokens := inputTokens + suffixTokens + outputTokens
@ -254,7 +264,7 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
for left < right { for left < right {
mid := (left + right) / 2 mid := (left + right) / 2
subPrompt := prompt[mid:] subPrompt := prompt[mid:]
subInputTokens := len(s.tokenizer.Encode(subPrompt, nil, nil)) subInputTokens := s.countToken(subPrompt)
totalTokens = subInputTokens + suffixTokens + outputTokens totalTokens = subInputTokens + suffixTokens + outputTokens
if totalTokens > s.cfg.CodexMaxTokens { if totalTokens > s.cfg.CodexMaxTokens {
left = mid + 1 left = mid + 1
@ -268,7 +278,7 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
body, _ = sjson.DeleteBytes(body, "extra") body, _ = sjson.DeleteBytes(body, "extra")
body, _ = sjson.DeleteBytes(body, "nwo") body, _ = sjson.DeleteBytes(body, "nwo")
body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL) body, _ = sjson.SetBytes(body, "model", InstructModel)
proxyUrl := s.cfg.CodexApiBase + "/completions" proxyUrl := s.cfg.CodexApiBase + "/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))