Signed-off-by: wozulong <>
This commit is contained in:
wozulong 2024-05-17 18:36:40 +08:00
parent 38faf0bb1b
commit 7be0c5ecd2
1 changed files with 16 additions and 6 deletions

22
main.go
View File

@ -21,7 +21,7 @@ import (
"time"
)
const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct"
const InstructModel = "gpt-3.5-turbo-instruct"
type config struct {
Bind string `json:"bind"`
@ -146,7 +146,7 @@ func NewProxyService(cfg *config) (*ProxyService, error) {
return nil, err
}
tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL)
tokenizer, err := tiktoken.EncodingForModel(InstructModel)
if nil != err {
return nil, err
}
@ -180,6 +180,8 @@ func (s *ProxyService) completions(c *gin.Context) {
}
body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.DeleteBytes(body, "intent")
body, _ = sjson.DeleteBytes(body, "intent_threshold")
body, _ = sjson.DeleteBytes(body, "intent_content")
proxyUrl := s.cfg.ChatApiBase + "/chat/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
@ -227,6 +229,14 @@ func (s *ProxyService) completions(c *gin.Context) {
_, _ = io.Copy(c.Writer, resp.Body)
}
func (s *ProxyService) countToken(token string) int {
if "" == token {
return 0
}
return len(s.tokenizer.Encode(token, nil, nil))
}
func (s *ProxyService) codeCompletions(c *gin.Context) {
ctx := c.Request.Context()
@ -244,8 +254,8 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
prompt := gjson.GetBytes(body, "prompt").String()
suffix := gjson.GetBytes(body, "suffix").String()
inputTokens := len(s.tokenizer.Encode(prompt, nil, nil))
suffixTokens := len(s.tokenizer.Encode(suffix, nil, nil))
inputTokens := s.countToken(prompt)
suffixTokens := s.countToken(suffix)
outputTokens := int(gjson.GetBytes(body, "max_tokens").Int())
totalTokens := inputTokens + suffixTokens + outputTokens
@ -254,7 +264,7 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
for left < right {
mid := (left + right) / 2
subPrompt := prompt[mid:]
subInputTokens := len(s.tokenizer.Encode(subPrompt, nil, nil))
subInputTokens := s.countToken(subPrompt)
totalTokens = subInputTokens + suffixTokens + outputTokens
if totalTokens > s.cfg.CodexMaxTokens {
left = mid + 1
@ -268,7 +278,7 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
body, _ = sjson.DeleteBytes(body, "extra")
body, _ = sjson.DeleteBytes(body, "nwo")
body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL)
body, _ = sjson.SetBytes(body, "model", InstructModel)
proxyUrl := s.cfg.CodexApiBase + "/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))