diff --git a/.gitignore b/.gitignore
index 48180b1..a7fc7d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,4 +33,5 @@ bin/
 # Develop tools
 .vscode/
 .idea/
-*.swp
\ No newline at end of file
+*.swp
+config.json
\ No newline at end of file
diff --git a/README.md b/README.md
index 3da5661..66215e4 100644
--- a/README.md
+++ b/README.md
@@ -58,6 +58,7 @@
 3. 我目前就试了下 `VSCode` ,至于 `JetBrains` 等IDE尚未适配,如果你有相关经验,请告诉我。
 4. 项目基于 `MIT` 协议发布,你可以修改,请保留原作者信息。
 5. 有什么问题,请在论坛 https://linux.do 讨论,欢迎PR。
+6. 支持 cf(Cloudflare)模型,目前最大 token 约 1500;在配置中设置,例如 `"codex_model_default":"@hf/thebloke/deepseek-coder-6.7b-instruct-awq"`。
 
 ### Star History
 
diff --git a/config.json b/config.json
index 60bceb9..fb08ead 100644
--- a/config.json
+++ b/config.json
@@ -2,11 +2,12 @@
   "bind": "127.0.0.1:8181",
   "proxy_url": "",
   "timeout": 600,
-  "codex_api_base": "https://api-proxy.oaipro.com/v1/completions",
+  "codex_api_base": "https://api.oaipro.com/v1/completions",
   "codex_api_key": "sk-xxx",
   "codex_api_organization": "",
   "codex_api_project": "",
   "codex_max_tokens": 4093,
+  "codex_model_default":"gpt-3.5-turbo-instruct",
   "chat_api_base": "https://api-proxy.oaipro.com/v1",
   "chat_api_key": "sk-xxx",
   "chat_api_organization": "",
diff --git a/main.go b/main.go
index c89d734..8104c02 100644
--- a/main.go
+++ b/main.go
@@ -1,10 +1,12 @@
 package main
 
 import (
+	"bufio"
 	"bytes"
 	"context"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/linux-do/tiktoken-go"
 	"github.com/tidwall/gjson"
@@ -23,6 +25,160 @@ import (
 
 const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct"
 
+// GPTMessage is one chat message sent to a chat-style completion backend.
+type GPTMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// StreamResponse is one streamed chunk in the Cloudflare-style format; the
+// generated text is carried in its "response" field.
+type StreamResponse struct {
+	Response string `json:"response"`
+}
+
+// Message is the delta payload of an OpenAI-style streamed choice.
+type Message struct {
+	Role    string  `json:"role,omitempty"`
+	Content any     `json:"content,omitempty"`
+	Name    *string `json:"name,omitempty"`
+}
+
+// ChatCompletionsStreamResponseChoice is one entry of "choices" in an
+// OpenAI-style streamed chunk.
+type ChatCompletionsStreamResponseChoice struct {
+	Index        int     `json:"index"`
+	Delta        Message `json:"delta"`
+	FinishReason *string `json:"finish_reason,omitempty"`
+}
+
+// ChatCompletionsStreamResponse is an OpenAI-style "chat.completion.chunk"
+// object.
+type ChatCompletionsStreamResponse struct {
+	Id      string                                `json:"id"`
+	Object  string                                `json:"object"`
+	Created int64                                 `json:"created"`
+	Model   string                                `json:"model"`
+	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+}
+
+// CustomEvent is a pre-formatted server-sent event. Rendering writes only
+// Data; Event, Id and Retry are not used by Render.
+type CustomEvent struct {
+	Event string
+	Id    string
+	Retry uint
+	Data  any
+}
+
+type stringWriter interface {
+	io.Writer
+	writeString(string) (int, error)
+}
+
+type stringWrapper struct {
+	io.Writer
+}
+
+var dataReplacer = strings.NewReplacer(
+	"\n", "\ndata:",
+	"\r", "\\r")
+var contentType = []string{"text/event-stream"}
+var noCache = []string{"no-cache"}
+
+func (w stringWrapper) writeString(str string) (int, error) {
+	return w.Writer.Write([]byte(str))
+}
+
+// checkWriter returns writer itself when it already implements stringWriter
+// and wraps it otherwise.
+func checkWriter(writer io.Writer) stringWriter {
+	if w, ok := writer.(stringWriter); ok {
+		return w
+	}
+	return stringWrapper{writer}
+}
+
+func encode(writer io.Writer, event CustomEvent) error {
+	return writeData(checkWriter(writer), event.Data)
+}
+
+// writeData writes data as an SSE payload with embedded line breaks rewritten.
+// A payload that starts with "data" is terminated with a blank line.
+func writeData(w stringWriter, data any) error {
+	// Format once: asserting data.(string) would panic on non-string values.
+	payload := fmt.Sprint(data)
+	if _, err := dataReplacer.WriteString(w, payload); err != nil {
+		return err
+	}
+	if !strings.HasPrefix(payload, "data") {
+		return nil
+	}
+	_, err := w.writeString("\n\n")
+	return err
+}
+
+// Render writes the event to w so that CustomEvent can be passed to c.Render.
+func (r CustomEvent) Render(w http.ResponseWriter) error {
+	r.WriteContentType(w)
+	return encode(w, r)
+}
+
+// WriteContentType sets the SSE content type and, when the handler has not
+// set Cache-Control yet, disables caching.
+func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
+	header := w.Header()
+	header["Content-Type"] = contentType
+
+	if _, exist := header["Cache-Control"]; !exist {
+		header["Cache-Control"] = noCache
+	}
+}
+
+// GetTimestamp returns the current Unix time in seconds.
+func GetTimestamp() int64 {
+	return time.Now().Unix()
+}
+
+// StreamResponseCloudflare2OpenAI converts one Cloudflare-style chunk into an
+// OpenAI-style chunk. Id and Model are left empty for the caller to fill in.
+func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = cloudflareResponse.Response
+	choice.Delta.Role = "assistant"
+	return &ChatCompletionsStreamResponse{
+		Object:  "chat.completion.chunk",
+		Choices: []ChatCompletionsStreamResponseChoice{choice},
+		Created: GetTimestamp(),
+	}
+}
+
+// StreamResponse2OpenAI performs the same conversion as
+// StreamResponseCloudflare2OpenAI and is kept for existing callers.
+func StreamResponse2OpenAI(cloudflareResponse *StreamResponse) *ChatCompletionsStreamResponse {
+	return StreamResponseCloudflare2OpenAI(cloudflareResponse)
+}
+
+// SetEventStreamHeaders sets the response headers used for an SSE stream.
+func SetEventStreamHeaders(c *gin.Context) {
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+}
+
+const (
+	RequestIdKey = "X-Oneapi-Request-Id"
+)
+
+// GetResponseID builds the chunk id from the value stored in the gin context
+// under RequestIdKey; the suffix is empty when nothing stored one.
+func GetResponseID(c *gin.Context) string {
+	logID := c.GetString(RequestIdKey)
+	return fmt.Sprintf("chatcmpl-%s", logID)
+}
+
 type config struct {
 	Bind                 string `json:"bind"`
 	ProxyUrl             string `json:"proxy_url"`
@@ -32,6 +188,7 @@ type config struct {
 	CodexApiOrganization string `json:"codex_api_organization"`
 	CodexApiProject      string `json:"codex_api_project"`
 	CodexMaxTokens       int    `json:"codex_max_tokens"`
+	CodexModelDefault    string `json:"codex_model_default"`
 	ChatApiBase          string `json:"chat_api_base"`
 	ChatApiKey           string `json:"chat_api_key"`
 	ChatApiOrganization  string `json:"chat_api_organization"`
@@ -268,9 +425,30 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
 
 	body, _ = sjson.DeleteBytes(body, "extra")
 	body, _ = sjson.DeleteBytes(body, "nwo")
-	body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL)
-
-	proxyUrl := s.cfg.CodexApiBase + "/completions"
+	// The default instruct model uses the completions endpoint; any other
+	// configured model is sent to codex_api_base unchanged.
+	model := s.cfg.CodexModelDefault
+	proxyUrl := s.cfg.CodexApiBase
+	if model == "" || model == INSTRUCT_MODEL {
+		model = INSTRUCT_MODEL
+		proxyUrl += "/completions"
+	}
+	body, _ = sjson.SetBytes(body, "model", model)
+
+	// "deepseek-coder" and "@..." models get the prompt as chat messages.
+	if model == "deepseek-coder" || strings.HasPrefix(model, "@") {
+		systemPrompt := ""
+		if model == "deepseek-coder" {
+			systemPrompt = "You are a helpful assistant"
+		}
+		messages := []GPTMessage{
+			{Role: "system", Content: systemPrompt},
+			{Role: "user", Content: gjson.GetBytes(body, "prompt").String()},
+		}
+		body, _ = sjson.DeleteBytes(body, "prompt")
+		body, _ = sjson.SetBytes(body, "messages", messages)
+		body, _ = sjson.DeleteBytes(body, "n")
+	}
 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
 	if nil != err {
 		abortCodex(c, http.StatusInternalServerError)
@@ -314,7 +492,74 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
 		c.Header("Content-Type", contentType)
 	}
 
-	_, _ = io.Copy(c.Writer, resp.Body)
+	// Re-emit the upstream stream line by line as server-sent events. The
+	// default line splitter also strips a trailing "\r".
+	reqCtx := c.Request.Context()
+	dataChan := make(chan string)
+	go func() {
+		// Closing dataChan tells the writer below that upstream is finished.
+		defer close(dataChan)
+		scanner := bufio.NewScanner(resp.Body)
+		for scanner.Scan() {
+			data := scanner.Text()
+			if len(data) < len("data: ") {
+				continue
+			}
+			data = strings.TrimPrefix(data, "data: ")
+			select {
+			case dataChan <- data:
+			case <-reqCtx.Done():
+				// Client is gone; stop instead of blocking on the send forever.
+				return
+			}
+		}
+		if err := scanner.Err(); err != nil && reqCtx.Err() == nil {
+			log.Println("error reading stream response: ", err.Error())
+		}
+	}()
+	SetEventStreamHeaders(c)
+	id := GetResponseID(c)
+	responseModel := c.GetString("original_model")
+	if responseModel == "" {
+		// Nothing stored the original model; report the one actually used.
+		responseModel = model
+	}
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case <-reqCtx.Done():
+			return false
+		case data, ok := <-dataChan:
+			if !ok {
+				c.Render(-1, CustomEvent{Data: "data: [DONE]"})
+				return false
+			}
+			if data == "[DONE]" {
+				// Upstream terminator; ours is sent once the stream closes.
+				return true
+			}
+			var codeResponse StreamResponse
+			if err := json.Unmarshal([]byte(data), &codeResponse); err != nil {
+				log.Println("error unmarshalling stream response: ", err.Error())
+				return true
+			}
+			if model == INSTRUCT_MODEL {
+				// Already OpenAI-shaped: pass through with the usual prefix.
+				c.Render(-1, CustomEvent{Data: "data: " + data})
+				return true
+			}
+			response := StreamResponseCloudflare2OpenAI(&codeResponse)
+			response.Id = id
+			response.Model = responseModel
+			jsonStr, err := json.Marshal(response)
+			if err != nil {
+				log.Println("error marshalling stream response: ", err.Error())
+				return true
+			}
+			c.Render(-1, CustomEvent{Data: "data: " + string(jsonStr)})
+			return true
+		}
+	})
+	_ = resp.Body.Close()
 }
 
 func main() {