diff --git a/server/internal/code.go b/server/internal/code.go deleted file mode 100644 index 102d49a9e5..0000000000 --- a/server/internal/code.go +++ /dev/null @@ -1,29 +0,0 @@ -package internal - -type Code struct { - code string - msg string -} - -var ( - codeOk = NewCode("0", "ok") - codeSuccess = NewCode("0", "success") - - codeErrParamsInvalid = NewCode("10000", "params invalid") - codeErrWorkersLimit = NewCode("10001", "workers limit") - codeErrChannelNotExisted = NewCode("10002", "channel not existed") - codeErrChannelExisted = NewCode("10003", "channel existed") - codeErrChannelEmpty = NewCode("10004", "channel empty") - codeErrGenerateTokenFailed = NewCode("10005", "generate token failed") - - codeErrProcessManifestFailed = NewCode("10100", "process manifest json failed") - codeErrStartWorkerFailed = NewCode("10101", "start worker failed") - codeErrStopAppFailed = NewCode("10102", "stop worker failed") -) - -func NewCode(code string, msg string) *Code { - return &Code{ - code: code, - msg: msg, - } -} diff --git a/server/internal/http_server.go b/server/internal/http_server.go index cf49a94f0f..e35fb94a98 100644 --- a/server/internal/http_server.go +++ b/server/internal/http_server.go @@ -11,360 +11,54 @@ package internal import ( - "fmt" + "app/internal/router" + "app/internal/service" + "context" "log/slog" "net/http" - "os" - "strings" - "time" - rtctokenbuilder "github.com/AgoraIO/Tools/DynamicKey/AgoraDynamicKey/go/src/rtctokenbuilder2" "github.com/gin-gonic/gin" - "github.com/gin-gonic/gin/binding" - "github.com/gogf/gf/crypto/gmd5" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) type HttpServer struct { - config *HttpServerConfig + deps HttpServerDepends + server *http.Server } -type HttpServerConfig struct { - AppId string - AppCertificate string - ManifestJsonFile string - Port string - TTSVendorChinese string - TTSVendorEnglish string - WorkersMax int - WorkerQuitTimeoutSeconds int -} - -type PingReq struct { - RequestId string `form:"request_id,omitempty" json:"request_id,omitempty"` - ChannelName string `form:"channel_name,omitempty" json:"channel_name,omitempty"` -} - -type StartReq struct { - RequestId string `form:"request_id,omitempty" json:"request_id,omitempty"` - AgoraAsrLanguage string `form:"agora_asr_language,omitempty" json:"agora_asr_language,omitempty"` - ChannelName string `form:"channel_name,omitempty" json:"channel_name,omitempty"` - RemoteStreamId uint32 `form:"remote_stream_id,omitempty" json:"remote_stream_id,omitempty"` - VoiceType string `form:"voice_type,omitempty" json:"voice_type,omitempty"` -} - -type StopReq struct { - RequestId string `form:"request_id,omitempty" json:"request_id,omitempty"` - ChannelName string `form:"channel_name,omitempty" json:"channel_name,omitempty"` +type HttpServerDepends struct { + Config HttpServerConfig + MainSvc *service.MainService } -type GenerateTokenReq struct { - RequestId string `form:"request_id,omitempty" json:"request_id,omitempty"` - ChannelName string `form:"channel_name,omitempty" json:"channel_name,omitempty"` - Uid uint32 `form:"uid,omitempty" json:"uid,omitempty"` +type HttpServerConfig struct { + Address string } -const ( - privilegeExpirationInSeconds = uint32(86400) - tokenExpirationInSeconds = uint32(86400) - - languageChinese = "zh-CN" - languageEnglish = "en-US" - - ManifestJsonFile = "./agents/manifest.json" - ManifestJsonFileElevenlabs = "./agents/manifest.elevenlabs.json" - - TTSVendorAzure = "azure" - TTSVendorElevenlabs = "elevenlabs" - - voiceTypeMale = "male" - voiceTypeFemale = "female" -) - var ( - voiceNameMap = map[string]map[string]map[string]string{ - languageChinese: { - TTSVendorAzure: { - voiceTypeMale: "zh-CN-YunxiNeural", - voiceTypeFemale: "zh-CN-XiaoxiaoNeural", - }, - TTSVendorElevenlabs: { - voiceTypeMale: "pNInz6obpgDQGcFmaJgB", // Adam - voiceTypeFemale: "Xb7hH8MSUJpSbSDYk0k2", // Alice - }, - }, - languageEnglish: { - TTSVendorAzure: { - voiceTypeMale: "en-US-BrianNeural", - voiceTypeFemale: "en-US-JaneNeural", - }, - TTSVendorElevenlabs: { - voiceTypeMale: "pNInz6obpgDQGcFmaJgB", // Adam - voiceTypeFemale: "Xb7hH8MSUJpSbSDYk0k2", // Alice - }, - }, - } - logTag = slog.String("service", "HTTP_SERVER") ) -func NewHttpServer(httpServerConfig *HttpServerConfig) *HttpServer { - return &HttpServer{ - config: httpServerConfig, - } -} - -func (s *HttpServer) getManifestJsonFile(language string) (manifestJsonFile string) { - ttsVendor := s.getTtsVendor(language) - manifestJsonFile = ManifestJsonFile - - if ttsVendor == TTSVendorElevenlabs { - manifestJsonFile = ManifestJsonFileElevenlabs - } - - return -} - -func (s *HttpServer) getTtsVendor(language string) string { - if language == languageChinese { - return s.config.TTSVendorChinese - } - - return s.config.TTSVendorEnglish -} - -func (s *HttpServer) handlerHealth(c *gin.Context) { - slog.Debug("handlerHealth", logTag) - s.output(c, codeOk, nil) -} - -func (s *HttpServer) handlerPing(c *gin.Context) { - var req PingReq - - if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil { - slog.Error("handlerPing params invalid", "err", err, logTag) - s.output(c, codeErrParamsInvalid, http.StatusBadRequest) - return - } - - slog.Info("handlerPing start", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) - - if strings.TrimSpace(req.ChannelName) == "" { - slog.Error("handlerPing channel empty", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) - s.output(c, codeErrChannelEmpty, http.StatusBadRequest) - return - } - - if !workers.Contains(req.ChannelName) { - slog.Error("handlerPing channel not existed", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) - s.output(c, codeErrChannelNotExisted, http.StatusBadRequest) - return - } - - // Update worker - worker := workers.Get(req.ChannelName).(*Worker) - worker.UpdateTs = time.Now().Unix() - - slog.Info("handlerPing end", "worker", worker, "requestId", req.RequestId, logTag) - s.output(c, codeSuccess, nil) -} - -func (s *HttpServer) handlerStart(c *gin.Context) { - workersRunning := workers.Size() - - slog.Info("handlerStart start", "workersRunning", workersRunning, logTag) - - var req StartReq - if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil { - slog.Error("handlerStart params invalid", "err", err, "requestId", req.RequestId, logTag) - s.output(c, codeErrParamsInvalid, http.StatusBadRequest) - return - } - - if strings.TrimSpace(req.ChannelName) == "" { - slog.Error("handlerStart channel empty", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) - s.output(c, codeErrChannelEmpty, http.StatusBadRequest) - return - } - - if workersRunning >= s.config.WorkersMax { - slog.Error("handlerStart workers exceed", "workersRunning", workersRunning, "WorkersMax", s.config.WorkersMax, "requestId", req.RequestId, logTag) - s.output(c, codeErrWorkersLimit, http.StatusTooManyRequests) - return - } - - if workers.Contains(req.ChannelName) { - slog.Error("handlerStart channel existed", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) - s.output(c, codeErrChannelExisted, http.StatusBadRequest) - return - } - - manifestJsonFile, logFile, err := s.processManifest(&req) - if err != nil { - slog.Error("handlerStart process manifest", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) - s.output(c, codeErrProcessManifestFailed, http.StatusInternalServerError) - return - } - - worker := newWorker(req.ChannelName, logFile, manifestJsonFile) - worker.QuitTimeoutSeconds = s.config.WorkerQuitTimeoutSeconds - if err := worker.start(&req); err != nil { - slog.Error("handlerStart start worker failed", "err", err, "requestId", req.RequestId, logTag) - s.output(c, codeErrStartWorkerFailed, http.StatusInternalServerError) - return - } - workers.SetIfNotExist(req.ChannelName, worker) - - slog.Info("handlerStart end", "workersRunning", workers.Size(), "worker", worker, "requestId", req.RequestId, logTag) - s.output(c, codeSuccess, nil) -} - -func (s *HttpServer) handlerStop(c *gin.Context) { - var req StopReq - - if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil { - slog.Error("handlerStop params invalid", "err", err, logTag) - s.output(c, codeErrParamsInvalid, http.StatusBadRequest) - return - } - - slog.Info("handlerStop start", "req", req, logTag) - - if strings.TrimSpace(req.ChannelName) == "" { - slog.Error("handlerStop channel empty", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) - s.output(c, codeErrChannelEmpty, http.StatusBadRequest) - return - } - - if !workers.Contains(req.ChannelName) { - slog.Error("handlerStop channel not existed", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) - s.output(c, codeErrChannelNotExisted, http.StatusBadRequest) - return - } - - worker := workers.Get(req.ChannelName).(*Worker) - if err := worker.stop(req.RequestId, req.ChannelName); err != nil { - slog.Error("handlerStop kill app failed", "err", err, "worker", workers.Get(req.ChannelName), "requestId", req.RequestId, logTag) - s.output(c, codeErrStopAppFailed, http.StatusInternalServerError) - return - } - - slog.Info("handlerStop end", "requestId", req.RequestId, logTag) - s.output(c, codeSuccess, nil) -} - -func (s *HttpServer) handlerGenerateToken(c *gin.Context) { - var req GenerateTokenReq - - if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil { - slog.Error("handlerGenerateToken params invalid", "err", err, logTag) - s.output(c, codeErrParamsInvalid, http.StatusBadRequest) - return - } - - slog.Info("handlerGenerateToken start", "req", req, logTag) - - if strings.TrimSpace(req.ChannelName) == "" { - slog.Error("handlerGenerateToken channel empty", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) - s.output(c, codeErrChannelEmpty, http.StatusBadRequest) - return - } - - if s.config.AppCertificate == "" { - s.output(c, codeSuccess, map[string]any{"appId": s.config.AppId, "token": s.config.AppId, "channel_name": req.ChannelName, "uid": req.Uid}) - return - } - - token, err := rtctokenbuilder.BuildTokenWithUid(s.config.AppId, s.config.AppCertificate, req.ChannelName, req.Uid, rtctokenbuilder.RolePublisher, tokenExpirationInSeconds, privilegeExpirationInSeconds) - if err != nil { - slog.Error("handlerGenerateToken generate token failed", "err", err, "requestId", req.RequestId, logTag) - s.output(c, codeErrGenerateTokenFailed, http.StatusBadRequest) - return - } +func NewHttpServer(deps HttpServerDepends) *HttpServer { + r := gin.Default() + r.Use(corsMiddleware()) - slog.Info("handlerGenerateToken end", "requestId", req.RequestId, logTag) - s.output(c, codeSuccess, map[string]any{"appId": s.config.AppId, "token": token, "channel_name": req.ChannelName, "uid": req.Uid}) -} + router.Apply(r, deps.MainSvc) -func (s *HttpServer) output(c *gin.Context, code *Code, data any, httpStatus ...int) { - if len(httpStatus) == 0 { - httpStatus = append(httpStatus, http.StatusOK) + return &HttpServer{ + deps: deps, + server: &http.Server{ + Addr: deps.Config.Address, + Handler: r, + }, } - - c.JSON(httpStatus[0], gin.H{"code": code.code, "msg": code.msg, "data": data}) } -func (s *HttpServer) processManifest(req *StartReq) (manifestJsonFile string, logFile string, err error) { - manifestJsonFile = s.getManifestJsonFile(req.AgoraAsrLanguage) - content, err := os.ReadFile(manifestJsonFile) - if err != nil { - slog.Error("handlerStart read manifest.json failed", "err", err, "manifestJsonFile", manifestJsonFile, "requestId", req.RequestId, logTag) - return - } - - manifestJson := string(content) - - if s.config.AppId != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.app_id`, s.config.AppId) - } - appId := gjson.Get(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.app_id`).String() - - // Generate token - token := appId - if s.config.AppCertificate != "" { - token, err = rtctokenbuilder.BuildTokenWithUid(appId, s.config.AppCertificate, req.ChannelName, 0, rtctokenbuilder.RoleSubscriber, tokenExpirationInSeconds, privilegeExpirationInSeconds) - if err != nil { - slog.Error("handlerStart generate token failed", "err", err, "requestId", req.RequestId, logTag) - return - } - } - - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.token`, token) - if req.AgoraAsrLanguage != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.agora_asr_language`, req.AgoraAsrLanguage) - } - if req.ChannelName != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.channel`, req.ChannelName) - } - if req.RemoteStreamId != 0 { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.remote_stream_id`, req.RemoteStreamId) - } - - language := gjson.Get(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.agora_asr_language`).String() - - ttsVendor := s.getTtsVendor(language) - voiceName := voiceNameMap[language][ttsVendor][req.VoiceType] - if voiceName != "" { - if ttsVendor == TTSVendorAzure { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="azure_tts").property.azure_synthesis_voice_name`, voiceName) - } else if ttsVendor == TTSVendorElevenlabs { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="elevenlabs_tts").property.voice_id`, voiceName) - } - } - - channelNameMd5 := gmd5.MustEncryptString(req.ChannelName) - ts := time.Now().UnixNano() - manifestJsonFile = fmt.Sprintf("/tmp/manifest-%s-%d.json", channelNameMd5, ts) - logFile = fmt.Sprintf("/tmp/app-%s-%d.log", channelNameMd5, ts) - os.WriteFile(manifestJsonFile, []byte(manifestJson), 0644) - - return +func (s *HttpServer) Run() error { + slog.Info("server start", "address", s.server.Addr, logTag) + go s.deps.MainSvc.CleanWorker() + return s.server.ListenAndServe() } -func (s *HttpServer) Start() { - r := gin.Default() - r.Use(corsMiddleware()) - - r.GET("/", s.handlerHealth) - r.GET("/health", s.handlerHealth) - r.POST("/ping", s.handlerPing) - r.POST("/start", s.handlerStart) - r.POST("/stop", s.handlerStop) - r.POST("/token/generate", s.handlerGenerateToken) - - slog.Info("server start", "port", s.config.Port, logTag) - - go cleanWorker() - r.Run(s.config.Port) +func (s *HttpServer) Shutdown(ctx context.Context) error { + return s.server.Shutdown(ctx) } diff --git a/server/internal/provider/manifest.go b/server/internal/provider/manifest.go new file mode 100644 index 0000000000..63a36439eb --- /dev/null +++ b/server/internal/provider/manifest.go @@ -0,0 +1,117 @@ +package provider + +import ( + "log/slog" + "os" + "path/filepath" + "regexp" + + "github.com/tidwall/sjson" +) + +type ManifestProvider struct { + // manifestJsons + // key: fileName `manifest.json` `manifest.elevenlabs.json` + // value: manifestJson + manifestJsons map[string]string +} + +func NewManifestProvider() *ManifestProvider { + return &ManifestProvider{ + manifestJsons: make(map[string]string), + } +} + +func (p *ManifestProvider) LoadManifest(manifestJsonDir string) error { + files, err := os.ReadDir(manifestJsonDir) + if err != nil { + slog.Error("read manifestJsonDir failed", "err", err, "manifestJsonDir", manifestJsonDir) + return err + } + + matcher := regexp.MustCompile(`^manifest(\..+)?\.json$`) + for _, file := range files { + if file.IsDir() { + continue + } + if !matcher.MatchString(file.Name()) { + continue + } + + filePath := filepath.Join(manifestJsonDir, file.Name()) + content, err := os.ReadFile(filePath) + if err != nil { + slog.Error("read manifest.json failed", "err", err, "filePath", filePath) + return err + } + + manifestJson := string(content) + manifestJson = p.injectEnvVar(manifestJson) + + p.manifestJsons[file.Name()] = manifestJson + } + + return nil +} + +func (p *ManifestProvider) injectEnvVar(manifestJson string) string { + appId := os.Getenv("AGORA_APP_ID") + if appId != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.app_id`, appId) + } + + azureSttKey := os.Getenv("AZURE_STT_KEY") + if azureSttKey != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.agora_asr_vendor_key`, azureSttKey) + } + + azureSttRegion := os.Getenv("AZURE_STT_REGION") + if azureSttRegion != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.agora_asr_vendor_region`, azureSttRegion) + } + + openaiBaseUrl := os.Getenv("OPENAI_BASE_URL") + if openaiBaseUrl != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="openai_chatgpt").property.base_url`, openaiBaseUrl) + } + + openaiApiKey := os.Getenv("OPENAI_API_KEY") + if openaiApiKey != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="openai_chatgpt").property.api_key`, openaiApiKey) + } + + openaiModel := os.Getenv("OPENAI_MODEL") + if openaiModel != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="openai_chatgpt").property.model`, openaiModel) + } + + proxyUrl := os.Getenv("PROXY_URL") + if proxyUrl != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="openai_chatgpt").property.proxy_url`, proxyUrl) + } + + azureTtsKey := os.Getenv("AZURE_TTS_KEY") + if azureTtsKey != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="azure_tts").property.azure_subscription_key`, azureTtsKey) + } + + azureTtsRegion := os.Getenv("AZURE_TTS_REGION") + if azureTtsRegion != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="azure_tts").property.azure_subscription_region`, azureTtsRegion) + } + + elevenlabsTtsKey := os.Getenv("ELEVENLABS_TTS_KEY") + if elevenlabsTtsKey != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="elevenlabs_tts").property.api_key`, elevenlabsTtsKey) + } + + return manifestJson +} + +func (p *ManifestProvider) GetManifestJson(vendor string) (string, bool) { + if len(vendor) > 0 { + vendor = "." + vendor + } + manifestJson, ok := p.manifestJsons["manifest"+vendor+".json"] + return manifestJson, ok +} diff --git a/server/internal/router/router.go b/server/internal/router/router.go new file mode 100644 index 0000000000..02dc67b5b6 --- /dev/null +++ b/server/internal/router/router.go @@ -0,0 +1,16 @@ +package router + +import ( + "app/internal/service" + + "github.com/gin-gonic/gin" +) + +func Apply(r gin.IRouter, mainSvc *service.MainService) { + r.GET("/", mainSvc.HandlerHealth) + r.GET("/health", mainSvc.HandlerHealth) + r.POST("/ping", mainSvc.HandlerPing) + r.POST("/start", mainSvc.HandlerStart) + r.POST("/stop", mainSvc.HandlerStop) + r.POST("/token/generate", mainSvc.HandlerGenerateToken) +} diff --git a/server/internal/service/service.go b/server/internal/service/service.go new file mode 100644 index 0000000000..f48d5bf054 --- /dev/null +++ b/server/internal/service/service.go @@ -0,0 +1,314 @@ +package service + +import ( + "app/internal/provider" + "app/pkg/common" + pkgProvider "app/pkg/provider" + "errors" + "fmt" + "log/slog" + "net/http" + "os" + "strings" + "time" + + rtctokenbuilder "github.com/AgoraIO/Tools/DynamicKey/AgoraDynamicKey/go/src/rtctokenbuilder2" + "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" + "github.com/gogf/gf/container/gmap" + "github.com/gogf/gf/crypto/gmd5" + "github.com/google/uuid" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + privilegeExpirationInSeconds = uint32(86400) + tokenExpirationInSeconds = uint32(86400) +) + +var ( + logTag = slog.String("service", "MAIN_SERVICE") +) + +type MainService struct { + deps MainServiceDepends + workers *gmap.Map +} + +type MainServiceDepends struct { + Config MainServiceConfig + ManifestProvider *provider.ManifestProvider +} + +type MainServiceConfig struct { + AppId string + AppCertificate string + TTSVendorChinese string + TTSVendorEnglish string + WorkersMax int + WorkerQuitTimeoutSeconds int +} + +func NewMainService(deps MainServiceDepends) *MainService { + return &MainService{ + deps: deps, + workers: gmap.New(true), + } +} + +func (s *MainService) output(c *gin.Context, code *common.Code, data any, httpStatus ...int) { + if len(httpStatus) == 0 { + httpStatus = append(httpStatus, http.StatusOK) + } + + c.JSON(httpStatus[0], gin.H{"code": code.Code, "msg": code.Msg, "data": data}) +} + +func (s *MainService) HandlerHealth(c *gin.Context) { + slog.Debug("handlerHealth", logTag) + s.output(c, common.CodeOk, nil) +} + +func (s *MainService) HandlerPing(c *gin.Context) { + var req common.PingReq + + if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil { + slog.Error("handlerPing params invalid", "err", err, logTag) + s.output(c, common.CodeErrParamsInvalid, http.StatusBadRequest) + return + } + + slog.Info("handlerPing start", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) + + if strings.TrimSpace(req.ChannelName) == "" { + slog.Error("handlerPing channel empty", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrChannelEmpty, http.StatusBadRequest) + return + } + + if !s.workers.Contains(req.ChannelName) { + slog.Error("handlerPing channel not existed", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrChannelNotExisted, http.StatusBadRequest) + return + } + + // Update worker + worker := s.workers.Get(req.ChannelName).(*Worker) + worker.UpdateTs = time.Now().Unix() + + slog.Info("handlerPing end", "worker", worker, "requestId", req.RequestId, logTag) + s.output(c, common.CodeSuccess, nil) +} + +// HandlerStart is a handle for start worker. +func (s *MainService) HandlerStart(c *gin.Context) { + workersRunning := s.workers.Size() + + slog.Info("handlerStart start", "workersRunning", workersRunning, logTag) + + var req common.StartReq + if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil { + slog.Error("handlerStart params invalid", "err", err, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrParamsInvalid, http.StatusBadRequest) + return + } + + if strings.TrimSpace(req.ChannelName) == "" { + slog.Error("handlerStart channel empty", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrChannelEmpty, http.StatusBadRequest) + return + } + + if workersRunning >= s.deps.Config.WorkersMax { + slog.Error("handlerStart workers exceed", "workersRunning", workersRunning, "WorkersMax", s.deps.Config.WorkersMax, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrWorkersLimit, http.StatusTooManyRequests) + return + } + + if s.workers.Contains(req.ChannelName) { + slog.Error("handlerStart channel existed", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrChannelExisted, http.StatusBadRequest) + return + } + + manifestJsonFile, logFile, err := s.createWorkerManifest(&req) + if err != nil { + slog.Error("handlerStart create worker manifest", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrProcessManifestFailed, http.StatusInternalServerError) + return + } + + worker := newWorker(req.ChannelName, logFile, manifestJsonFile) + worker.QuitTimeoutSeconds = s.deps.Config.WorkerQuitTimeoutSeconds + if err := worker.start(&req); err != nil { + slog.Error("handlerStart start worker failed", "err", err, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrStartWorkerFailed, http.StatusInternalServerError) + return + } + s.workers.SetIfNotExist(req.ChannelName, worker) + + slog.Info("handlerStart end", "workersRunning", s.workers.Size(), "worker", worker, "requestId", req.RequestId, logTag) + s.output(c, common.CodeSuccess, nil) +} + +func (s *MainService) HandlerStop(c *gin.Context) { + var req common.StopReq + + if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil { + slog.Error("handlerStop params invalid", "err", err, logTag) + s.output(c, common.CodeErrParamsInvalid, http.StatusBadRequest) + return + } + + slog.Info("handlerStop start", "req", req, logTag) + + if strings.TrimSpace(req.ChannelName) == "" { + slog.Error("handlerStop channel empty", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrChannelEmpty, http.StatusBadRequest) + return + } + + if !s.workers.Contains(req.ChannelName) { + slog.Error("handlerStop channel not existed", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrChannelNotExisted, http.StatusBadRequest) + return + } + + worker := s.workers.Get(req.ChannelName).(*Worker) + if err := worker.stop(req.RequestId, req.ChannelName); err != nil { + slog.Error("handlerStop kill app failed", "err", err, "worker", s.workers.Get(req.ChannelName), "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrStopAppFailed, http.StatusInternalServerError) + return + } + s.workers.Remove(req.ChannelName) + + slog.Info("handlerStop end", "requestId", req.RequestId, logTag) + s.output(c, common.CodeSuccess, nil) +} + +func (s *MainService) HandlerGenerateToken(c *gin.Context) { + var req common.GenerateTokenReq + + if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil { + slog.Error("handlerGenerateToken params invalid", "err", err, logTag) + s.output(c, common.CodeErrParamsInvalid, http.StatusBadRequest) + return + } + + slog.Info("handlerGenerateToken start", "req", req, logTag) + + if strings.TrimSpace(req.ChannelName) == "" { + slog.Error("handlerGenerateToken channel empty", "channelName", req.ChannelName, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrChannelEmpty, http.StatusBadRequest) + return + } + + if s.deps.Config.AppCertificate == "" { + s.output(c, common.CodeSuccess, map[string]any{"appId": s.deps.Config.AppId, "token": s.deps.Config.AppId, "channel_name": req.ChannelName, "uid": req.Uid}) + return + } + + token, err := rtctokenbuilder.BuildTokenWithUid(s.deps.Config.AppId, s.deps.Config.AppCertificate, req.ChannelName, req.Uid, rtctokenbuilder.RolePublisher, tokenExpirationInSeconds, privilegeExpirationInSeconds) + if err != nil { + slog.Error("handlerGenerateToken generate token failed", "err", err, "requestId", req.RequestId, logTag) + s.output(c, common.CodeErrGenerateTokenFailed, http.StatusBadRequest) + return + } + + slog.Info("handlerGenerateToken end", "requestId", req.RequestId, logTag) + s.output(c, common.CodeSuccess, map[string]any{"appId": s.deps.Config.AppId, "token": token, "channel_name": req.ChannelName, "uid": req.Uid}) +} + +// createWorkerManifest create worker temporary Mainfest. +func (s *MainService) createWorkerManifest(req *common.StartReq) (manifestJsonFile string, logFile string, err error) { + vendor := s.getTtsVendor(req.AgoraAsrLanguage) + tts := pkgProvider.GetTts(vendor) + if tts == nil { + err = errors.New(fmt.Sprintf("unknow tts vendor", vendor)) + slog.Error("handlerStart generate token failed", "err", err, "requestId", req.RequestId, logTag) + return "", "", err + } + + manifestJson, ok := s.deps.ManifestProvider.GetManifestJson(vendor) + if !ok { + err = errors.New(fmt.Sprintf("unknow manifest vendor", vendor)) + slog.Error("handlerStart get manifest json failed", "err", err, "requestId", req.RequestId, logTag) + return "", "", err + } + + if s.deps.Config.AppId != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.app_id`, s.deps.Config.AppId) + } + appId := gjson.Get(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.app_id`).String() + + // Generate token + token := appId + if s.deps.Config.AppCertificate != "" { + token, err = rtctokenbuilder.BuildTokenWithUid(appId, s.deps.Config.AppCertificate, req.ChannelName, 0, rtctokenbuilder.RoleSubscriber, tokenExpirationInSeconds, privilegeExpirationInSeconds) + if err != nil { + slog.Error("handlerStart generate token failed", "err", err, "requestId", req.RequestId, logTag) + return "", "", err + } + } + + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.token`, token) + if req.AgoraAsrLanguage != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.agora_asr_language`, req.AgoraAsrLanguage) + } + if req.ChannelName != "" { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.channel`, req.ChannelName) + } + if req.RemoteStreamId != 0 { + manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.remote_stream_id`, req.RemoteStreamId) + } + + language := gjson.Get(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.agora_asr_language`).String() + manifestJson, err = tts.ProcessManifest(manifestJson, common.Language(language), req.VoiceType) + if err != nil { + slog.Error("handlerStart tts ProcessManifest failed", "err", err, "requestId", req.RequestId, logTag) + return "", "", err + } + + channelNameMd5 := gmd5.MustEncryptString(req.ChannelName) + ts := time.Now().UnixNano() + manifestJsonFile = fmt.Sprintf("/tmp/manifest-%s-%d.json", channelNameMd5, ts) + logFile = fmt.Sprintf("/tmp/app-%s-%d.log", channelNameMd5, ts) + err = os.WriteFile(manifestJsonFile, []byte(manifestJson), 0644) + if err != nil { + slog.Error("handlerStart write manifest.json failed", "err", err, "manifestJsonFile", manifestJsonFile, "requestId", req.RequestId, logTag) + return "", "", err + } + + return manifestJsonFile, logFile, nil +} + +// CleanWorker clean unused workers in background. +func (s *MainService) CleanWorker() { + for { + for _, channelName := range s.workers.Keys() { + worker := s.workers.Get(channelName).(*Worker) + + nowTs := time.Now().Unix() + if worker.UpdateTs+int64(worker.QuitTimeoutSeconds) < nowTs { + if err := worker.stop(uuid.New().String(), channelName.(string)); err != nil { + slog.Error("Worker cleanWorker failed", "err", err, "channelName", channelName, logTag) + continue + } + + slog.Info("Worker cleanWorker success", "channelName", channelName, "worker", worker, "nowTs", nowTs, logTag) + } + } + + slog.Debug("Worker cleanWorker sleep", "sleep", workerCleanSleepSeconds, logTag) + time.Sleep(workerCleanSleepSeconds * time.Second) + } +} + +func (s *MainService) getTtsVendor(language common.Language) string { + if language == common.LanguageChinese { + return s.deps.Config.TTSVendorChinese + } + + return s.deps.Config.TTSVendorEnglish +} diff --git a/server/internal/worker.go b/server/internal/service/worker.go similarity index 57% rename from server/internal/worker.go rename to server/internal/service/worker.go index e3daf80a83..b1546fdba1 100644 --- a/server/internal/worker.go +++ b/server/internal/service/worker.go @@ -1,15 +1,14 @@ -package internal +package service import ( + "app/pkg/common" "fmt" "log/slog" "os/exec" "strconv" "strings" + "syscall" "time" - - "github.com/gogf/gf/container/gmap" - "github.com/google/uuid" ) type Worker struct { @@ -27,10 +26,6 @@ const ( workerExec = "/app/agents/bin/worker" ) -var ( - workers = gmap.New(true) -) - func newWorker(channelName string, logFile string, manifestJsonFile string) *Worker { return &Worker{ ChannelName: channelName, @@ -42,12 +37,12 @@ func newWorker(channelName string, logFile string, manifestJsonFile string) *Wor } } -func (w *Worker) start(req *StartReq) (err error) { +func (w *Worker) start(req *common.StartReq) error { shell := fmt.Sprintf("cd /app/agents && nohup %s --manifest %s > %s 2>&1 &", workerExec, w.ManifestJsonFile, w.LogFile) slog.Info("Worker start", "requestId", req.RequestId, "shell", shell, logTag) - if _, err = exec.Command("sh", "-c", shell).CombinedOutput(); err != nil { + if _, err := exec.Command("sh", "-c", shell).CombinedOutput(); err != nil { slog.Error("Worker start failed", "err", err, "requestId", req.RequestId, logTag) - return + return err } shell = fmt.Sprintf("ps aux | grep %s | grep -v grep | awk '{print $2}'", w.ManifestJsonFile) @@ -55,52 +50,28 @@ func (w *Worker) start(req *StartReq) (err error) { output, err := exec.Command("sh", "-c", shell).CombinedOutput() if err != nil { slog.Error("Worker get pid failed", "err", err, "requestId", req.RequestId, logTag) - return + return err } pid, err := strconv.Atoi(strings.TrimSpace(string(output))) if err != nil || pid <= 0 { slog.Error("Worker convert pid failed", "err", err, "pid", pid, "requestId", req.RequestId, logTag) - return + return err } w.Pid = pid - return + return nil } -func (w *Worker) stop(requestId string, channelName string) (err error) { +func (w *Worker) stop(requestId string, channelName string) error { slog.Info("Worker stop start", "channelName", channelName, "requestId", requestId, logTag) - shell := fmt.Sprintf("kill -9 %d", w.Pid) - output, err := exec.Command("sh", "-c", shell).CombinedOutput() + err := syscall.Kill(w.Pid, syscall.SIGTERM) if err != nil { - slog.Error("Worker kill failed", "err", err, "output", output, "channelName", channelName, "worker", w, "requestId", requestId, logTag) - return + slog.Error("Worker kill failed", "err", err, "channelName", channelName, "worker", w, "requestId", requestId, logTag) + return err } - workers.Remove(channelName) - slog.Info("Worker stop end", "channelName", channelName, "worker", w, "requestId", requestId, logTag) - return -} - -func cleanWorker() { - for { - for _, channelName := range workers.Keys() { - worker := workers.Get(channelName).(*Worker) - - nowTs := time.Now().Unix() - if worker.UpdateTs+int64(worker.QuitTimeoutSeconds) < nowTs { - if err := worker.stop(uuid.New().String(), channelName.(string)); err != nil { - slog.Error("Worker cleanWorker failed", "err", err, "channelName", channelName, logTag) - continue - } - - slog.Info("Worker cleanWorker success", "channelName", channelName, "worker", worker, "nowTs", nowTs, logTag) - } - } - - slog.Debug("Worker cleanWorker sleep", "sleep", workerCleanSleepSeconds, logTag) - time.Sleep(workerCleanSleepSeconds * time.Second) - } + return err } diff --git a/server/main.go b/server/main.go index d190387bc2..400ee6b274 100644 --- a/server/main.go +++ b/server/main.go @@ -1,27 +1,32 @@ package main import ( + "context" + "errors" "flag" "log/slog" + "net/http" "os" + "os/signal" "strconv" - - "github.com/tidwall/sjson" + "syscall" + "time" "app/internal" + "app/internal/provider" + "app/internal/service" + "app/third_party/azure" ) func main() { - httpServerConfig := &internal.HttpServerConfig{} - ttsVendorChinese := os.Getenv("TTS_VENDOR_CHINESE") if len(ttsVendorChinese) == 0 { - ttsVendorChinese = internal.TTSVendorAzure + ttsVendorChinese = azure.NAME } ttsVendorEnglish := os.Getenv("TTS_VENDOR_ENGLISH") if len(ttsVendorEnglish) == 0 { - ttsVendorEnglish = internal.TTSVendorAzure + ttsVendorEnglish = azure.NAME } workersMax, err := strconv.Atoi(os.Getenv("WORKERS_MAX")) @@ -34,83 +39,69 @@ func main() { workerQuitTimeoutSeconds = 60 } - flag.StringVar(&httpServerConfig.AppId, "appId", os.Getenv("AGORA_APP_ID"), "agora appid") - flag.StringVar(&httpServerConfig.AppCertificate, "appCertificate", os.Getenv("AGORA_APP_CERTIFICATE"), "agora certificate") - flag.StringVar(&httpServerConfig.Port, "port", ":8080", "http server port") - flag.StringVar(&httpServerConfig.TTSVendorChinese, "ttsVendorChinese", ttsVendorChinese, "tts vendor for chinese") - flag.StringVar(&httpServerConfig.TTSVendorEnglish, "ttsVendorEnglish", ttsVendorEnglish, "tts vendor for english") - flag.IntVar(&httpServerConfig.WorkersMax, "workersMax", workersMax, "workers max") - flag.IntVar(&httpServerConfig.WorkerQuitTimeoutSeconds, "workerQuitTimeoutSeconds", workerQuitTimeoutSeconds, "worker quit timeout seconds") - flag.Parse() + var manifestJsonDir string + flag.StringVar(&manifestJsonDir, "manifestJsonDir", "./agents/", "manifest json dir") - slog.Info("server config", "ttsVendorChinese", httpServerConfig.TTSVendorChinese, "ttsVendorEnglish", httpServerConfig.TTSVendorEnglish, - "workersMax", httpServerConfig.WorkersMax, "workerQuitTimeoutSeconds", httpServerConfig.WorkerQuitTimeoutSeconds) + httpServerConfig := internal.HttpServerConfig{} + flag.StringVar(&httpServerConfig.Address, "port", ":8080", "http server listen address") - processManifest(internal.ManifestJsonFile) - processManifest(internal.ManifestJsonFileElevenlabs) - httpServer := internal.NewHttpServer(httpServerConfig) - httpServer.Start() -} - -func processManifest(manifestJsonFile string) (err error) { - content, err := os.ReadFile(manifestJsonFile) - if err != nil { - slog.Error("read manifest.json failed", "err", err, "manifestJsonFile", manifestJsonFile) - return - } + mainServiceConfig := service.MainServiceConfig{} + flag.StringVar(&mainServiceConfig.AppId, "appId", os.Getenv("AGORA_APP_ID"), "agora appid") + flag.StringVar(&mainServiceConfig.AppCertificate, "appCertificate", os.Getenv("AGORA_APP_CERTIFICATE"), "agora certificate") + flag.StringVar(&mainServiceConfig.TTSVendorChinese, "ttsVendorChinese", ttsVendorChinese, "tts vendor for chinese") + flag.StringVar(&mainServiceConfig.TTSVendorEnglish, "ttsVendorEnglish", ttsVendorEnglish, "tts vendor for english") + flag.IntVar(&mainServiceConfig.WorkersMax, "workersMax", workersMax, "workers max") + flag.IntVar(&mainServiceConfig.WorkerQuitTimeoutSeconds, "workerQuitTimeoutSeconds", workerQuitTimeoutSeconds, "worker quit timeout seconds") - manifestJson := string(content) - - appId := os.Getenv("AGORA_APP_ID") - if appId != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.app_id`, appId) - } - - azureSttKey := os.Getenv("AZURE_STT_KEY") - if azureSttKey != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.agora_asr_vendor_key`, azureSttKey) - } - - azureSttRegion := os.Getenv("AZURE_STT_REGION") - if azureSttRegion != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="agora_rtc").property.agora_asr_vendor_region`, azureSttRegion) - } - - openaiBaseUrl := os.Getenv("OPENAI_BASE_URL") - if openaiBaseUrl != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="openai_chatgpt").property.base_url`, openaiBaseUrl) - } - - openaiApiKey := os.Getenv("OPENAI_API_KEY") - if openaiApiKey != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="openai_chatgpt").property.api_key`, openaiApiKey) - } - - openaiModel := os.Getenv("OPENAI_MODEL") - if openaiModel != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="openai_chatgpt").property.model`, openaiModel) - } + flag.Parse() - proxyUrl := os.Getenv("PROXY_URL") - if proxyUrl != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="openai_chatgpt").property.proxy_url`, proxyUrl) - } + slog.Info("server config", + "ttsVendorChinese", mainServiceConfig.TTSVendorChinese, + "ttsVendorEnglish", mainServiceConfig.TTSVendorEnglish, + "workersMax", mainServiceConfig.WorkersMax, + "workerQuitTimeoutSeconds", mainServiceConfig.WorkerQuitTimeoutSeconds) - azureTtsKey := os.Getenv("AZURE_TTS_KEY") - if azureTtsKey != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="azure_tts").property.azure_subscription_key`, azureTtsKey) + manifestProvider := provider.NewManifestProvider() + err = manifestProvider.LoadManifest(manifestJsonDir) + if err != nil { + panic(err) } - azureTtsRegion := os.Getenv("AZURE_TTS_REGION") - if azureTtsRegion != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="azure_tts").property.azure_subscription_region`, azureTtsRegion) + mainSvc := service.NewMainService(service.MainServiceDepends{ + Config: mainServiceConfig, + ManifestProvider: manifestProvider, + }) + + httpServer := internal.NewHttpServer(internal.HttpServerDepends{ + Config: httpServerConfig, + MainSvc: mainSvc, + }) + + errCh := make(chan error, 1) + go func() { + defer close(errCh) + err := httpServer.Run() + if errors.Is(err, http.ErrServerClosed) { + errCh <- nil + } + errCh <- err + }() + + sigCh := make(chan os.Signal, 1) + defer close(sigCh) + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGINT) + <-sigCh + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Second*3) + err = httpServer.Shutdown(ctx) + if err != nil { + slog.Error("httpServer Shutdown error", "err", err) } + defer cancel() // fix warning lostcancel - elevenlabsTtsKey := os.Getenv("ELEVENLABS_TTS_KEY") - if elevenlabsTtsKey != "" { - manifestJson, _ = sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="elevenlabs_tts").property.api_key`, elevenlabsTtsKey) + err = <-errCh + if err != nil { + panic(err) } - - err = os.WriteFile(manifestJsonFile, []byte(manifestJson), 0644) - return } diff --git a/server/pkg/common/code.go b/server/pkg/common/code.go new file mode 100644 index 0000000000..a77cb6f540 --- /dev/null +++ b/server/pkg/common/code.go @@ -0,0 +1,29 @@ +package common + +type Code struct { + Code string + Msg string +} + +var ( + CodeOk = NewCode("0", "ok") + CodeSuccess = NewCode("0", "success") + + CodeErrParamsInvalid = NewCode("10000", "params invalid") + CodeErrWorkersLimit = NewCode("10001", "workers limit") + CodeErrChannelNotExisted = NewCode("10002", "channel not existed") + CodeErrChannelExisted = NewCode("10003", "channel existed") + CodeErrChannelEmpty = NewCode("10004", "channel empty") + CodeErrGenerateTokenFailed = NewCode("10005", "generate token failed") + + CodeErrProcessManifestFailed = NewCode("10100", "process manifest json failed") + CodeErrStartWorkerFailed = NewCode("10101", "start worker failed") + CodeErrStopAppFailed = NewCode("10102", "stop worker failed") +) + +func NewCode(code string, msg string) *Code { + return &Code{ + Code: code, + Msg: msg, + } +} diff --git a/server/pkg/common/request.go b/server/pkg/common/request.go new file mode 100644 index 0000000000..d46adad5b8 --- /dev/null +++ b/server/pkg/common/request.go @@ -0,0 +1,39 @@ +package common + +type Language string + +const ( + LanguageChinese Language = "zh-CN" + LanguageEnglish Language = "en-US" +) + +type VoiceType string + +const ( + VoiceTypeMale = "male" + VoiceTypeFemale = "female" +) + +type PingReq struct { + RequestId string `form:"request_id,omitempty" json:"request_id,omitempty"` + ChannelName string `form:"channel_name,omitempty" json:"channel_name,omitempty"` +} + +type StartReq struct { + RequestId string `form:"request_id,omitempty" json:"request_id,omitempty"` + AgoraAsrLanguage Language `form:"agora_asr_language,omitempty" json:"agora_asr_language,omitempty"` + ChannelName string `form:"channel_name,omitempty" json:"channel_name,omitempty"` + RemoteStreamId uint32 `form:"remote_stream_id,omitempty" json:"remote_stream_id,omitempty"` + VoiceType VoiceType `form:"voice_type,omitempty" json:"voice_type,omitempty"` +} + +type StopReq struct { + RequestId string `form:"request_id,omitempty" json:"request_id,omitempty"` + ChannelName string `form:"channel_name,omitempty" json:"channel_name,omitempty"` +} + +type GenerateTokenReq struct { + RequestId string `form:"request_id,omitempty" json:"request_id,omitempty"` + ChannelName string `form:"channel_name,omitempty" json:"channel_name,omitempty"` + Uid uint32 `form:"uid,omitempty" json:"uid,omitempty"` +} diff --git a/server/pkg/provider/tts.go b/server/pkg/provider/tts.go new file mode 100644 index 0000000000..2e351aced3 --- /dev/null +++ b/server/pkg/provider/tts.go @@ -0,0 +1,24 @@ +package provider + +import "app/pkg/common" + +var registeredTts = make(map[string]ITtsProvider) + +type ITtsProvider interface { + Name() string + ProcessManifest(manifestJson string, language common.Language, voiceType common.VoiceType) (string, error) +} + +func RegisterTts(provider ITtsProvider) { + if provider == nil { + panic("cannot register a nil ITtsProvider") + } + if provider.Name() == "" { + panic("cannot register ITtsProvider with empty string result for Name()") + } + registeredTts[provider.Name()] = provider +} + +func GetTts(name string) ITtsProvider { + return registeredTts[name] +} diff --git a/server/third_party/azure/tts.go b/server/third_party/azure/tts.go new file mode 100644 index 0000000000..8ff4ba79ab --- /dev/null +++ b/server/third_party/azure/tts.go @@ -0,0 +1,52 @@ +package azure + +import ( + "app/pkg/common" + "app/pkg/provider" + "errors" + + "github.com/tidwall/sjson" +) + +const NAME string = "azure" + +func init() { + provider.RegisterTts(NewAzureTtsProvider()) +} + +type AzureTtsProvider struct { + voiceNameMap map[common.Language]map[common.VoiceType]string +} + +func NewAzureTtsProvider() provider.ITtsProvider { + return &AzureTtsProvider{ + voiceNameMap: map[common.Language]map[common.VoiceType]string{ + common.LanguageChinese: { + common.VoiceTypeMale: "zh-CN-YunxiNeural", + common.VoiceTypeFemale: "zh-CN-XiaoxiaoNeural", + }, + common.LanguageEnglish: { + common.VoiceTypeMale: "en-US-BrianNeural", + common.VoiceTypeFemale: "en-US-JaneNeural", + }, + }, + } +} + +// Name implements provider.ITtsProvider. +func (p *AzureTtsProvider) Name() string { + return NAME +} + +// ProcessManifest implements provider.ITtsProvider. +func (p *AzureTtsProvider) ProcessManifest(manifestJson string, language common.Language, voiceType common.VoiceType) (string, error) { + voiceTypeMap, ok := p.voiceNameMap[language] + if !ok { + return "", errors.New("unknow language") + } + voiceName, ok := voiceTypeMap[voiceType] + if !ok { + return "", errors.New("unknow voiceType") + } + return sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="azure_tts").property.azure_synthesis_voice_name`, voiceName) +} diff --git a/server/third_party/elevenlabs/tts.go b/server/third_party/elevenlabs/tts.go new file mode 100644 index 0000000000..b01519cacf --- /dev/null +++ b/server/third_party/elevenlabs/tts.go @@ -0,0 +1,52 @@ +package elevenlabs + +import ( + "app/pkg/common" + "app/pkg/provider" + "errors" + + "github.com/tidwall/sjson" +) + +const NAME string = "elevenlabs" + +func init() { + provider.RegisterTts(NewElevenLabsTtsProvider()) +} + +type ElevenLabsTtsProvider struct { + voiceNameMap map[common.Language]map[common.VoiceType]string +} + +func NewElevenLabsTtsProvider() provider.ITtsProvider { + return &ElevenLabsTtsProvider{ + voiceNameMap: map[common.Language]map[common.VoiceType]string{ + common.LanguageChinese: { + common.VoiceTypeMale: "pNInz6obpgDQGcFmaJgB", // Adam + common.VoiceTypeFemale: "Xb7hH8MSUJpSbSDYk0k2", // Alice + }, + common.LanguageEnglish: { + common.VoiceTypeMale: "pNInz6obpgDQGcFmaJgB", // Adam + common.VoiceTypeFemale: "Xb7hH8MSUJpSbSDYk0k2", // Alice + }, + }, + } +} + +// Name implements provider.ITtsProvider. +func (p *ElevenLabsTtsProvider) Name() string { + return NAME +} + +// ProcessManifest implements provider.ITtsProvider. +func (p *ElevenLabsTtsProvider) ProcessManifest(manifestJson string, language common.Language, voiceType common.VoiceType) (string, error) { + voiceTypeMap, ok := p.voiceNameMap[language] + if !ok { + return "", errors.New("unknow language") + } + voiceName, ok := voiceTypeMap[voiceType] + if !ok { + return "", errors.New("unknow voiceType") + } + return sjson.Set(manifestJson, `predefined_graphs.0.nodes.#(name=="elevenlabs_tts").property.voice_id`, voiceName) +}