// override/main.go
//
// 557 lines
// 14 KiB
// Go

package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/linux-do/tiktoken-go"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/net/http2"
"io"
"log"
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"time"
)
// INSTRUCT_MODEL names the OpenAI completions model that code completions fall back to
// when no Codex model is configured; the tokenizer used for prompt budgeting is also
// chosen for this model.
const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct"

type (
	// GPTMessage is one entry of a chat "messages" array sent upstream.
	GPTMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// StreamResponse is one decoded upstream stream event; only its "response" field
	// is kept, any other fields are ignored when decoding.
	StreamResponse struct {
		Response string `json:"response"`
	}

	// Message is the delta carried by a streamed chat choice. An empty Role, a nil
	// Content and a nil Name are left out of the JSON encoding.
	Message struct {
		Role    string  `json:"role,omitempty"`
		Content any     `json:"content,omitempty"`
		Name    *string `json:"name,omitempty"`
	}

	// ChatCompletionsStreamResponseChoice is a single choice inside a streamed chunk;
	// FinishReason is omitted while it is nil.
	ChatCompletionsStreamResponseChoice struct {
		Index        int     `json:"index"`
		Delta        Message `json:"delta"`
		FinishReason *string `json:"finish_reason,omitempty"`
	}

	// ChatCompletionsStreamResponse is the streamed chunk payload written to clients
	// (the converters set Object to "chat.completion.chunk").
	ChatCompletionsStreamResponse struct {
		Id      string                                `json:"id"`
		Object  string                                `json:"object"`
		Created int64                                 `json:"created"`
		Model   string                                `json:"model"`
		Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
	}
)
// CustomEvent is a minimal gin render.Render for Server-Sent Events. Only Data is
// written to the response; Event, Id and Retry are carried but never serialized here.
type CustomEvent struct {
	Event string
	Id    string
	Retry uint
	Data  any
}

// stringWriter is an io.Writer that can also write a string directly.
type stringWriter interface {
	io.Writer
	writeString(string) (int, error)
}

// stringWrapper adapts a plain io.Writer to stringWriter.
type stringWrapper struct {
	io.Writer
}

var (
	// dataReplacer keeps multi-line payloads inside SSE "data:" framing: each embedded
	// "\n" starts a new "data:" line, and "\r" is written as the two characters `\r`
	// because SSE also treats a bare CR as a line terminator.
	dataReplacer = strings.NewReplacer(
		"\n", "\ndata:",
		"\r", "\\r")
	contentType = []string{"text/event-stream"}
	noCache     = []string{"no-cache"}
)

func (w stringWrapper) writeString(str string) (int, error) {
	return w.Writer.Write([]byte(str))
}

// checkWriter returns writer itself when it already implements stringWriter,
// otherwise wraps it in a stringWrapper.
func checkWriter(writer io.Writer) stringWriter {
	if w, ok := writer.(stringWriter); ok {
		return w
	}
	return stringWrapper{writer}
}

// encode writes event's Data to writer; see writeData.
func encode(writer io.Writer, event CustomEvent) error {
	return writeData(checkWriter(writer), event.Data)
}

// writeData writes fmt.Sprint(data) through dataReplacer. When that text starts with
// "data", a blank line ("\n\n") is appended to terminate the SSE event; other text is
// written without a terminator.
//
// Any value is accepted: the prefix check uses the formatted text instead of a
// data.(string) assertion, which panicked for non-string Data. Write errors are
// deliberately not returned (best-effort, matching the original behavior). Some gin
// versions panic when Render fails, and gin's Stream loop already notices a client
// that went away.
func writeData(w stringWriter, data any) error {
	text := fmt.Sprint(data)
	_, _ = dataReplacer.WriteString(w, text)
	if strings.HasPrefix(text, "data") {
		_, _ = w.writeString("\n\n")
	}
	return nil
}

// Render sets the SSE response headers and then writes the event to w.
func (r CustomEvent) Render(w http.ResponseWriter) error {
	r.WriteContentType(w)
	return encode(w, r)
}

// WriteContentType marks the response as text/event-stream and, unless a
// Cache-Control header is already present, sets it to no-cache.
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
	header := w.Header()
	header["Content-Type"] = contentType
	if _, exist := header["Cache-Control"]; !exist {
		header["Cache-Control"] = noCache
	}
}
// GetTimestamp returns the current time as Unix seconds.
func GetTimestamp() int64 {
	return time.Now().Unix()
}

// StreamResponseCloudflare2OpenAI wraps one upstream stream event as a
// "chat.completion.chunk" with a single assistant delta whose content is the event's
// Response text. Created is set to the current Unix time. Id and Model are left empty
// for the caller to fill in. A nil event yields nil.
func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *ChatCompletionsStreamResponse {
	if cloudflareResponse == nil {
		return nil
	}
	return &ChatCompletionsStreamResponse{
		Object:  "chat.completion.chunk",
		Created: GetTimestamp(),
		Choices: []ChatCompletionsStreamResponseChoice{{
			Delta: Message{Role: "assistant", Content: cloudflareResponse.Response},
		}},
	}
}

// StreamResponse2OpenAI performs exactly the same conversion as
// StreamResponseCloudflare2OpenAI and is kept for existing callers.
//
// Deprecated: use StreamResponseCloudflare2OpenAI.
func StreamResponse2OpenAI(cloudflareResponse *StreamResponse) *ChatCompletionsStreamResponse {
	return StreamResponseCloudflare2OpenAI(cloudflareResponse)
}
// SetEventStreamHeaders sets the response headers used for Server-Sent Events:
// event-stream content type, no caching, keep-alive, chunked transfer encoding, and
// "X-Accel-Buffering: no" (asks nginx-style reverse proxies not to buffer the stream).
func SetEventStreamHeaders(c *gin.Context) {
	headers := c.Writer.Header()
	for _, kv := range [][2]string{
		{"Content-Type", "text/event-stream"},
		{"Cache-Control", "no-cache"},
		{"Connection", "keep-alive"},
		{"Transfer-Encoding", "chunked"},
		{"X-Accel-Buffering", "no"},
	} {
		headers.Set(kv[0], kv[1])
	}
}

// RequestIdKey is the gin context key that GetResponseID reads the request ID from.
const RequestIdKey = "X-Oneapi-Request-Id"

// GetResponseID returns "chatcmpl-" followed by the string stored under RequestIdKey
// in the gin context ("" when nothing stored one).
// NOTE(review): nothing in this file appears to set RequestIdKey, so IDs may always be
// the bare prefix; confirm whether a middleware is expected to populate it.
func GetResponseID(c *gin.Context) string {
	return "chatcmpl-" + c.GetString(RequestIdKey)
}
// config holds the proxy settings loaded from config.json (see readConfig for the
// OVERRIDE_* environment overrides).
type config struct {
	Bind                 string            `json:"bind"`                   // listen address passed to gin's Run
	ProxyUrl             string            `json:"proxy_url"`              // optional proxy for upstream requests
	Timeout              int               `json:"timeout"`                // upstream http.Client timeout, seconds
	CodexApiBase         string            `json:"codex_api_base"`         // upstream base URL for code completions
	CodexApiKey          string            `json:"codex_api_key"`          // Bearer token for code completions
	CodexApiOrganization string            `json:"codex_api_organization"` // optional OpenAI-Organization header
	CodexApiProject      string            `json:"codex_api_project"`      // optional OpenAI-Project header
	CodexMaxTokens       int               `json:"codex_max_tokens"`       // budget for prompt + suffix + max_tokens
	CodexModelDefault    string            `json:"codex_model_default"`    // code model; empty means INSTRUCT_MODEL
	ChatApiBase          string            `json:"chat_api_base"`          // upstream base URL for chat completions
	ChatApiKey           string            `json:"chat_api_key"`           // Bearer token for chat completions
	ChatApiOrganization  string            `json:"chat_api_organization"`  // optional OpenAI-Organization header
	ChatApiProject       string            `json:"chat_api_project"`       // optional OpenAI-Project header
	ChatModelDefault     string            `json:"chat_model_default"`     // chat model used when the request's model is not mapped
	ChatModelMap         map[string]string `json:"chat_model_map"`         // requested chat model -> upstream model
}

// readConfig loads config.json from the current working directory and then applies
// environment overrides. For every field whose json tag name is NAME, the variable
// OVERRIDE_<upper-cased NAME> (e.g. OVERRIDE_CHAT_API_KEY) replaces the file value.
// Map/slice fields take a JSON document. Override values that cannot be parsed are
// logged and ignored. The process exits via log.Fatal if the file cannot be read or
// decoded.
func readConfig() *config {
	content, err := os.ReadFile("config.json")
	if err != nil {
		log.Fatal(err)
	}
	// Decode into *config (not **config): a literal `null` document must leave the
	// config zero-valued rather than replacing the pointer with nil.
	cfg := &config{}
	if err := json.Unmarshal(content, cfg); err != nil {
		log.Fatal(err)
	}
	applyEnvOverrides(cfg)
	return cfg
}

// applyEnvOverrides replaces fields of cfg with values from OVERRIDE_* environment variables.
func applyEnvOverrides(cfg *config) {
	v := reflect.ValueOf(cfg).Elem()
	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		// Only the name part of the tag counts ("name,omitempty" -> "name").
		name, _, _ := strings.Cut(t.Field(i).Tag.Get("json"), ",")
		if name == "" || name == "-" {
			continue
		}
		envKey := "OVERRIDE_" + strings.ToUpper(name)
		value, ok := os.LookupEnv(envKey)
		if !ok {
			continue
		}
		if err := setFromString(v.Field(i), value); err != nil {
			log.Printf("ignoring %s: %v", envKey, err)
		}
	}
}

// setFromString parses value according to field's kind and stores it in field.
// Kinds it does not handle are left unchanged.
func setFromString(field reflect.Value, value string) error {
	switch field.Kind() {
	case reflect.String:
		field.SetString(value)
	case reflect.Bool:
		b, err := strconv.ParseBool(value)
		if err != nil {
			return err
		}
		field.SetBool(b)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// Parsing with the field's own width rejects values that would overflow it.
		n, err := strconv.ParseInt(value, 10, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetInt(n)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		n, err := strconv.ParseUint(value, 10, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetUint(n)
	case reflect.Float32, reflect.Float64:
		f, err := strconv.ParseFloat(value, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetFloat(f)
	case reflect.Map, reflect.Slice:
		// Decode into a fresh value so the override replaces, rather than merges
		// with, whatever config.json provided.
		fresh := reflect.New(field.Type())
		if err := json.Unmarshal([]byte(value), fresh.Interface()); err != nil {
			return err
		}
		field.Set(fresh.Elem())
	}
	return nil
}
// getClient builds the shared upstream HTTP client.
//
// The transport has HTTP/2 configured via golang.org/x/net/http2. It routes through
// cfg.ProxyUrl when that is set; otherwise Proxy stays nil, so no proxy (not even one
// from the environment) is used. The client's Timeout is cfg.Timeout seconds
// (0 = no limit). http.Client's Timeout also covers reading the response body, so it
// bounds the length of streamed responses as well.
func getClient(cfg *config) (*http.Client, error) {
	tr := &http.Transport{ForceAttemptHTTP2: true} // keep-alives stay enabled (zero value)
	if err := http2.ConfigureTransport(tr); err != nil {
		return nil, err
	}

	if cfg.ProxyUrl != "" {
		u, err := url.Parse(cfg.ProxyUrl)
		if err != nil {
			return nil, err
		}
		tr.Proxy = http.ProxyURL(u)
	}

	return &http.Client{
		Transport: tr,
		Timeout:   time.Duration(cfg.Timeout) * time.Second,
	}, nil
}
// abortCodex ends a code-completion request early. It responds with the given status,
// a text/event-stream content type and the body "data: [DONE]\n", and stops the rest of
// the gin handler chain.
func abortCodex(c *gin.Context, status int) {
	defer c.Abort()
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.String(status, "data: [DONE]\n")
}

// closeIO closes c and logs, rather than returns, any error (for use with defer).
func closeIO(c io.Closer) {
	if err := c.Close(); err != nil {
		log.Println(err)
	}
}
// ProxyService serves the proxy's HTTP endpoints. It holds the loaded config, one
// shared upstream HTTP client, and the INSTRUCT_MODEL tokenizer used to budget
// code-completion prompts.
type ProxyService struct {
	cfg       *config
	client    *http.Client
	tokenizer *tiktoken.Tiktoken
}

// NewProxyService builds the upstream client from cfg and loads the tokenizer for
// INSTRUCT_MODEL, returning the first error encountered.
func NewProxyService(cfg *config) (*ProxyService, error) {
	svc := &ProxyService{cfg: cfg}
	var err error
	if svc.client, err = getClient(cfg); err != nil {
		return nil, err
	}
	if svc.tokenizer, err = tiktoken.EncodingForModel(INSTRUCT_MODEL); err != nil {
		return nil, err
	}
	return svc, nil
}

// InitRoutes registers the chat-completions and Copilot code-completions endpoints on e.
func (s *ProxyService) InitRoutes(e *gin.Engine) {
	routes := []struct {
		path    string
		handler gin.HandlerFunc
	}{
		{"/v1/chat/completions", s.completions},
		{"/v1/engines/copilot-codex/completions", s.codeCompletions},
	}
	for _, rt := range routes {
		e.POST(rt.path, rt.handler)
	}
}
// completions proxies POST /v1/chat/completions to ChatApiBase + "/chat/completions".
//
// The request's "model" is replaced by its ChatModelMap entry. When the model is not
// mapped, it is replaced by ChatModelDefault, provided one is configured; otherwise the
// client's model is kept instead of sending an empty one. The "intent" field is removed.
// The upstream status, Content-Type and body are relayed to the client, flushing after
// every chunk so streamed (SSE) responses arrive incrementally. Non-200 upstream bodies
// are also logged.
func (s *ProxyService) completions(c *gin.Context) {
	ctx := c.Request.Context()
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		c.AbortWithStatus(http.StatusBadRequest)
		return
	}

	model := gjson.GetBytes(body, "model").String()
	if mapped, ok := s.cfg.ChatModelMap[model]; ok {
		model = mapped
	} else if s.cfg.ChatModelDefault != "" {
		model = s.cfg.ChatModelDefault
	}
	body, _ = sjson.SetBytes(body, "model", model)
	body, _ = sjson.DeleteBytes(body, "intent")

	proxyURL := s.cfg.ChatApiBase + "/chat/completions"
	// Passing a *bytes.Reader directly (not wrapped in io.NopCloser) lets net/http set
	// Content-Length and GetBody, so the upload is not chunked and can be retried.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyURL, bytes.NewReader(body))
	if err != nil {
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+s.cfg.ChatApiKey)
	if s.cfg.ChatApiOrganization != "" {
		req.Header.Set("OpenAI-Organization", s.cfg.ChatApiOrganization)
	}
	if s.cfg.ChatApiProject != "" {
		req.Header.Set("OpenAI-Project", s.cfg.ChatApiProject)
	}

	resp, err := s.client.Do(req)
	if err != nil {
		if errors.Is(err, context.Canceled) {
			c.AbortWithStatus(http.StatusRequestTimeout)
			return
		}
		log.Println("request conversation failed:", err.Error())
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}
	defer closeIO(resp.Body)

	var src io.Reader = resp.Body
	if resp.StatusCode != http.StatusOK {
		// Buffer the error body so it can be both logged and relayed to the client.
		errBody, _ := io.ReadAll(resp.Body)
		log.Println("request completions failed:", string(errBody))
		src = bytes.NewReader(errBody)
	}

	c.Status(resp.StatusCode)
	if ct := resp.Header.Get("Content-Type"); ct != "" {
		c.Header("Content-Type", ct)
	}
	// Errors after the client has gone away (ctx cancelled) are expected and not logged.
	if err := copyFlushing(c.Writer, src); err != nil && ctx.Err() == nil {
		log.Println("relaying completions response failed:", err.Error())
	}
}

// copyFlushing copies src to dst like io.Copy, but flushes dst after every write when dst
// implements http.Flusher, so streamed responses are not held in the server's write
// buffer. It returns the first read or write error other than io.EOF.
func copyFlushing(dst io.Writer, src io.Reader) error {
	flusher, _ := dst.(http.Flusher)
	buf := make([]byte, 32*1024)
	for {
		n, readErr := src.Read(buf)
		if n > 0 {
			if _, err := dst.Write(buf[:n]); err != nil {
				return err
			}
			if flusher != nil {
				flusher.Flush()
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}
// codeCompletions handles POST /v1/engines/copilot-codex/completions.
//
// Steps:
//  1. Trim the prompt to the token budget.
//  2. Drop the "extra" and "nwo" fields and set the configured Codex model.
//  3. Forward the request upstream.
//  4. Relay the upstream stream back as SSE "data:" events, always ending with
//     "data: [DONE]".
//
// Routing by model:
//   - INSTRUCT_MODEL (also used when none is configured): the request goes to
//     CodexApiBase + "/completions" and events are relayed verbatim.
//   - Any other model: the request goes to CodexApiBase as-is, and each event's
//     "response" field is re-wrapped as a chat.completion.chunk.
//   - "deepseek-coder" and "@"-prefixed models: the prompt is additionally sent as
//     chat "messages".
func (s *ProxyService) codeCompletions(c *gin.Context) {
	ctx := c.Request.Context()

	// Hold the request briefly and give up if the client has already cancelled it
	// (presumably a debounce for rapid successive editor requests).
	debounce := time.NewTimer(100 * time.Millisecond)
	defer debounce.Stop()
	select {
	case <-ctx.Done():
		abortCodex(c, http.StatusRequestTimeout)
		return
	case <-debounce.C:
	}

	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		abortCodex(c, http.StatusBadRequest)
		return
	}
	body = s.fitCodexPrompt(body)
	body, _ = sjson.DeleteBytes(body, "extra")
	body, _ = sjson.DeleteBytes(body, "nwo")

	model := s.cfg.CodexModelDefault
	proxyURL := s.cfg.CodexApiBase
	if model == "" || model == INSTRUCT_MODEL {
		model = INSTRUCT_MODEL
		proxyURL += "/completions"
	}
	body, _ = sjson.SetBytes(body, "model", model)
	switch {
	case model == "deepseek-coder":
		body = codexPromptToMessages(body, "You are a helpful assistant")
	case strings.HasPrefix(model, "@"):
		body = codexPromptToMessages(body, "")
	}

	// A *bytes.Reader body lets net/http set Content-Length and GetBody.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyURL, bytes.NewReader(body))
	if err != nil {
		abortCodex(c, http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+s.cfg.CodexApiKey)
	if s.cfg.CodexApiOrganization != "" {
		req.Header.Set("OpenAI-Organization", s.cfg.CodexApiOrganization)
	}
	if s.cfg.CodexApiProject != "" {
		req.Header.Set("OpenAI-Project", s.cfg.CodexApiProject)
	}

	resp, err := s.client.Do(req)
	if err != nil {
		if errors.Is(err, context.Canceled) {
			abortCodex(c, http.StatusRequestTimeout)
			return
		}
		log.Println("request completions failed:", err.Error())
		abortCodex(c, http.StatusInternalServerError)
		return
	}
	defer closeIO(resp.Body)
	if resp.StatusCode != http.StatusOK {
		errBody, _ := io.ReadAll(resp.Body)
		log.Println("request completions failed:", string(errBody))
		abortCodex(c, resp.StatusCode)
		return
	}

	// Registered after closeIO, so it runs first: the reader goroutine is released
	// before the body is closed under it.
	done := make(chan struct{})
	defer close(done)
	events := readCodexEvents(resp.Body, done)

	c.Status(resp.StatusCode)
	SetEventStreamHeaders(c) // sets Content-Type to text/event-stream
	id := GetResponseID(c)
	responseModel := c.GetString("original_model")
	c.Stream(func(w io.Writer) bool {
		var data string
		select {
		case line, ok := <-events:
			if !ok {
				c.Render(-1, CustomEvent{Data: "data: [DONE]"})
				return false
			}
			data = line
		case <-ctx.Done():
			return false
		}

		// some implementations may add \r at the end of data
		data = strings.TrimSuffix(data, "\r")
		var event StreamResponse
		if err := json.Unmarshal([]byte(data), &event); err != nil {
			// The upstream "[DONE]" marker is swallowed; our own is sent when the stream ends.
			if data != "[DONE]" {
				log.Println("error unmarshalling stream response: ", err.Error())
			}
			return true
		}

		if model == INSTRUCT_MODEL {
			c.Render(-1, CustomEvent{Data: "data:" + data})
			return true
		}

		chunk := StreamResponseCloudflare2OpenAI(&event)
		if chunk == nil {
			return true
		}
		chunk.Id = id
		chunk.Model = responseModel
		jsonStr, err := json.Marshal(chunk)
		if err != nil {
			log.Println("error marshalling stream response: ", err.Error())
			return true
		}
		c.Render(-1, CustomEvent{Data: "data: " + string(jsonStr)})
		return true
	})
}

// fitCodexPrompt trims text from the front of the request's "prompt". The goal is that
// prompt tokens + suffix tokens + max_tokens stay within cfg.CodexMaxTokens, with the
// end of the prompt kept. A non-positive CodexMaxTokens is treated as "no limit";
// otherwise an unset (0) budget would wipe almost every prompt.
func (s *ProxyService) fitCodexPrompt(body []byte) []byte {
	limit := s.cfg.CodexMaxTokens
	if limit <= 0 {
		return body
	}
	countTokens := func(text string) int {
		return len(s.tokenizer.Encode(text, nil, nil))
	}
	prompt := gjson.GetBytes(body, "prompt").String()
	reserved := countTokens(gjson.GetBytes(body, "suffix").String()) +
		int(gjson.GetBytes(body, "max_tokens").Int())
	if countTokens(prompt)+reserved <= limit {
		return body
	}

	// Binary-search the smallest start offset whose tail fits. This assumes the token
	// count shrinks as the start offset grows.
	left, right := 0, len(prompt)
	for left < right {
		mid := left + (right-left)/2
		if countTokens(prompt[mid:])+reserved > limit {
			left = mid + 1
		} else {
			right = mid
		}
	}
	// Never start inside a multi-byte UTF-8 sequence: skip continuation bytes (10xxxxxx).
	// Moving forward only removes text, so the budget still holds.
	for left < len(prompt) && prompt[left]&0xC0 == 0x80 {
		left++
	}
	body, _ = sjson.SetBytes(body, "prompt", prompt[left:])
	return body
}

// codexPromptToMessages moves the request's "prompt" into a chat "messages" array: a
// system message with systemPrompt (possibly empty) followed by the prompt as the user
// message. It also removes "n".
func codexPromptToMessages(body []byte, systemPrompt string) []byte {
	prompt := gjson.GetBytes(body, "prompt").String()
	body, _ = sjson.DeleteBytes(body, "prompt")
	body, _ = sjson.SetBytes(body, "messages", []GPTMessage{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: prompt},
	})
	body, _ = sjson.DeleteBytes(body, "n")
	return body
}

// readCodexEvents scans body line by line in a goroutine. For every line at least
// len("data: ") bytes long, it sends the line with a leading "data: " removed on the
// returned channel. The channel is closed when the body ends or fails, or as soon as
// done is closed, so the goroutine can never block forever on a reader that has gone.
func readCodexEvents(body io.Reader, done <-chan struct{}) <-chan string {
	events := make(chan string)
	go func() {
		defer close(events)
		scanner := bufio.NewScanner(body)
		// Allow lines beyond bufio's 64 KiB default so one large event cannot end the stream.
		scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
		for scanner.Scan() {
			line := scanner.Text()
			if len(line) < len("data: ") {
				continue
			}
			select {
			case events <- strings.TrimPrefix(line, "data: "):
			case <-done:
				return
			}
		}
		if err := scanner.Err(); err != nil {
			select {
			case <-done: // handler already returned and closed the body; nothing to report
			default:
				log.Println("error reading stream response: ", err.Error())
			}
		}
	}()
	return events
}
// main loads config.json (plus OVERRIDE_* environment overrides), builds the proxy
// service, and serves it with gin in release mode on cfg.Bind. Any startup or serve
// error terminates the process via log.Fatal.
func main() {
	cfg := readConfig()
	gin.SetMode(gin.ReleaseMode)
	engine := gin.Default()

	svc, err := NewProxyService(cfg)
	if err != nil {
		log.Fatal(err)
	}
	svc.InitRoutes(engine)

	if err := engine.Run(cfg.Bind); err != nil {
		log.Fatal(err)
	}
}