From f67b88d2e7869db90b2200a71def9486b0e6875c Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Thu, 28 May 2026 16:41:32 +0800 Subject: [PATCH] fix(go): resolve provider API inconsistencies - Add and implement /api/v1/models endpoint - Fix response format inconsistency in /api/v1/models/default - Fix response format and data inconsistency in /api/v1/providers - Fix data format inconsistency in /:provider_name/instances/:instance_name/models --- internal/dao/tenant_model.go | 20 ++ internal/dao/tenant_model_instance.go | 10 + internal/dao/tenant_model_provider.go | 10 + internal/entity/model.go | 25 ++- internal/handler/tenant.go | 59 ++++- internal/router/router.go | 7 +- internal/service/model_service.go | 28 +-- internal/service/tenant.go | 309 +++++++++++++++++++++++--- 8 files changed, 417 insertions(+), 51 deletions(-) diff --git a/internal/dao/tenant_model.go b/internal/dao/tenant_model.go index fd69c3ca415..4f5bafb0855 100644 --- a/internal/dao/tenant_model.go +++ b/internal/dao/tenant_model.go @@ -66,6 +66,16 @@ func (dao *TenantModelDAO) GetModelByProviderIDAndInstanceIDAndModelName(provide return &model, nil } +// GetModelByProviderIDAndInstanceIDAndModelTypeAndModelName gets a model by provider ID, instance ID, model type and model name +func (dao *TenantModelDAO) GetModelByProviderIDAndInstanceIDAndModelTypeAndModelName(providerID, instanceID, modelType, modelName string) (*entity.TenantModel, error) { + var model entity.TenantModel + err := DB.Where("provider_id = ? AND instance_id = ? AND model_type = ? AND model_name = ?", providerID, instanceID, modelType, modelName).First(&model).Error + if err != nil { + return nil, err + } + return &model, nil +} + // GetModelsByInstanceID get all models by instance ID func (dao *TenantModelDAO) GetModelsByInstanceID(instanceID string) ([]*entity.TenantModel, error) { var models []*entity.TenantModel @@ -75,3 +85,13 @@ func (dao *TenantModelDAO) GetModelsByInstanceID(instanceID string) ([]*entity.T } return models, nil } + +// GetModelsByProviderIDsAndInstanceIDs get all models by provider IDs and instance IDs +func (dao *TenantModelDAO) GetModelsByProviderIDsAndInstanceIDs(providerIDs, instanceIDs []string) ([]*entity.TenantModel, error) { + var models []*entity.TenantModel + err := DB.Where("provider_id IN ? AND instance_id IN ?", providerIDs, instanceIDs).Find(&models).Error + if err != nil { + return nil, err + } + return models, nil +} diff --git a/internal/dao/tenant_model_instance.go b/internal/dao/tenant_model_instance.go index 97eb4304e23..6e0b15cf11e 100644 --- a/internal/dao/tenant_model_instance.go +++ b/internal/dao/tenant_model_instance.go @@ -64,3 +64,13 @@ func (dao *TenantModelInstanceDAO) DeleteByProviderIDAndInstanceName(providerID, result := DB.Unscoped().Where("provider_id = ? and instance_name = ?", providerID, instanceName).Delete(&entity.TenantModelInstance{}) return result.RowsAffected, result.Error } + +// GetByProviderIDs get all instances by provider IDs +func (dao *TenantModelInstanceDAO) GetByProviderIDs(providerIDs []string) ([]*entity.TenantModelInstance, error) { + var instances []*entity.TenantModelInstance + err := DB.Where("provider_id IN ?", providerIDs).Find(&instances).Error + if err != nil { + return nil, err + } + return instances, nil +} diff --git a/internal/dao/tenant_model_provider.go b/internal/dao/tenant_model_provider.go index fd75353bdbb..76b99a3096c 100644 --- a/internal/dao/tenant_model_provider.go +++ b/internal/dao/tenant_model_provider.go @@ -72,3 +72,13 @@ func (dao *TenantModelProviderDAO) ListByID(id string) ([]string, error) { Pluck("provider_name", &providerNames).Error return providerNames, err } + +// GetByTenantID get all model providers by tenant ID +func (dao *TenantModelProviderDAO) GetByTenantID(tenantID string) ([]*entity.TenantModelProvider, error) { + var providers []*entity.TenantModelProvider + err := DB.Where("tenant_id = ?", tenantID).Find(&providers).Error + if err != nil { + return nil, err + } + return providers, nil +} diff --git a/internal/entity/model.go b/internal/entity/model.go index a0b69f7020f..0a9224eeec4 100644 --- a/internal/entity/model.go +++ b/internal/entity/model.go @@ -271,6 +271,7 @@ func (pm *ProviderManager) ListProviders() ([]map[string]interface{}, error) { modelTypeSet := make(map[string]struct{}) for _, model := range provider.Models { + for _, modelType := range model.ModelTypes { modelTypeSet[modelType] = struct{}{} } @@ -287,6 +288,9 @@ func (pm *ProviderManager) ListProviders() ([]map[string]interface{}, error) { "model_types": modelTypes, "url_suffix": provider.URLSuffix, } + if (len(modelTypes) == 0) { + continue + } providers = append(providers, providerData) } @@ -305,10 +309,25 @@ func (pm *ProviderManager) GetProviderByName(providerName string) (map[string]in return nil, fmt.Errorf("provider '%s' not found", providerName) } + modelTypeSet := make(map[string]struct{}) + for _, model := range provider.Models { + if len(model.ModelTypes) == 0 { + continue + } + for _, modelType := range model.ModelTypes { + modelTypeSet[modelType] = struct{}{} + } + } + + var modelTypes []string + for modelType := range modelTypeSet { + modelTypes = append(modelTypes, modelType) + } + providerInfo := map[string]interface{}{ - "name": provider.Name, - "base_url": provider.URL, - "total_models": len(provider.Models), + "name": provider.Name, + "url": provider.URL, + "model_types": modelTypes, } return providerInfo, nil diff --git a/internal/handler/tenant.go b/internal/handler/tenant.go index 4505cb7bcad..74ee055b5fe 100644 --- a/internal/handler/tenant.go +++ b/internal/handler/tenant.go @@ -73,7 +73,7 @@ func (h *TenantHandler) GetModels(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "code": common.CodeSuccess, "message": "success", - "data": defaultModels, + "data": gin.H{"models": defaultModels}, }) } @@ -119,6 +119,63 @@ func (h *TenantHandler) SetModels(c *gin.Context) { }) } +// GetAddedModels lists all added models for the current user's tenant +// @Summary List Added Models +// @Description List all models added to the current user's tenant +// @Tags models +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param type query string false "Model type filter (chat, embedding, rerank, asr, vision, tts, ocr)" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/models [get] +func (h *TenantHandler) GetAddedModels(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + // Get tenant ID for the user + tenantInfos, err := h.tenantService.GetTenantInfo(user.ID) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, + "message": err.Error(), + "data": nil, + }) + return + } + + if tenantInfos == nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeDataError, + "message": "Tenant not found", + "data": nil, + }) + return + } + + // Get optional model type filter from query params + modelTypeFilter := c.Query("type") + + addedModels, err := h.tenantService.ListTenantAddedModels(tenantInfos.TenantID, modelTypeFilter) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeExceptionError, + "message": err.Error(), + "data": nil, + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "message": "success", + "data": addedModels, + }) +} + // TenantInfo get tenant information // @Summary Get Tenant Information // @Description Get current user's tenant information (owner tenant) diff --git a/internal/router/router.go b/internal/router/router.go index d343d8ee1b7..4df198782b1 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -271,7 +271,7 @@ func (r *Router) Setup(engine *gin.Engine) { // provider pool route group provider := v1.Group("/providers") { - provider.GET("/", r.providerHandler.ListProviders) + provider.GET("", r.providerHandler.ListProviders) provider.PUT("/", r.providerHandler.AddProvider) provider.GET("/:provider_name", r.providerHandler.ShowProvider) provider.DELETE("/:provider_name", r.providerHandler.DeleteProvider) @@ -301,8 +301,9 @@ func (r *Router) Setup(engine *gin.Engine) { model := v1.Group("/models") { - model.GET("/", r.tenantHandler.GetModels) - model.PATCH("/", r.tenantHandler.SetModels) + model.GET("/default", r.tenantHandler.GetModels) + model.PATCH("/default", r.tenantHandler.SetModels) + model.GET("", r.tenantHandler.GetAddedModels) // GET /api/v1/models - list tenant added models } connector := v1.Group("/connectors") diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 2ebbcb34e03..5b003912e8e 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -306,16 +306,16 @@ func (m *ModelProviderService) ListProviderInstances(providerName, userID string var extra map[string]string err = json.Unmarshal([]byte(instance.Extra), &extra) if err != nil { - return nil, common.CodeServerError, err + extra = make(map[string]string) } result = append(result, map[string]interface{}{ - "id": instance.ID, - "instanceName": instance.InstanceName, - "providerID": instance.ProviderID, - "apiKey": instance.APIKey, - "status": instance.Status, - "extra": instance.Extra, + "api_key": instance.APIKey, + "id": instance.ID, + "instance_name": instance.InstanceName, + "provider_id": instance.ProviderID, + "region": extra["region"], + "status": instance.Status, }) } @@ -351,15 +351,15 @@ func (m *ModelProviderService) ShowProviderInstance(providerName, instanceName, var extra map[string]string err = json.Unmarshal([]byte(instance.Extra), &extra) if err != nil { - return nil, common.CodeServerError, err + extra = make(map[string]string) } result := map[string]interface{}{ - "id": instance.ID, - "instanceName": instance.InstanceName, - "providerID": instance.ProviderID, - "status": instance.Status, - "region": extra["region"], + "id": instance.ID, + "instance_name": instance.InstanceName, + "provider_id": instance.ProviderID, + "status": instance.Status, + "region": extra["region"], } return result, common.CodeSuccess, nil @@ -744,6 +744,8 @@ func (m *ModelProviderService) ListInstanceModels(providerName, instanceName, us for _, model := range allModels { // convert model["name"] to string modelName := model["name"].(string) + model["model_type"] = model["model_types"] + delete(model, "model_types") if modelNames[modelName] { model["status"] = "inactive" } else { diff --git a/internal/service/tenant.go b/internal/service/tenant.go index 7ff19c5f373..e106f660975 100644 --- a/internal/service/tenant.go +++ b/internal/service/tenant.go @@ -23,6 +23,7 @@ import ( "ragflow/internal/dao" "ragflow/internal/engine" "ragflow/internal/entity" + "ragflow/internal/server" "strings" ) @@ -404,7 +405,19 @@ func (s *TenantService) GetDefaultModelName(tenantID string, modelType entity.Mo return modelID, nil } -func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, modelType string) (*string, *string, *string, bool, error) { +// MODEL_TAG_TO_TYPE maps model type tags to standard model type names +// This matches Python's MODEL_TAG_TO_TYPE in models_api_service.py +var MODEL_TAG_TO_TYPE = map[string]string{ + "chat": "chat", + "embedding": "embedding", + "rerank": "rerank", + "asr": "speech2text", + "vision": "image2text", + "tts": "tts", + "ocr": "ocr", +} + +func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, modelType string) (*string, *string, *string, string, bool, error) { // normally the model string is: modelName@instanceName@providerName, sometimes it's just modelName@providerName // for the 1st case, parse defaultChatModel into three parts defaultChatModelParts := strings.Split(defaultModel, "@") @@ -422,47 +435,52 @@ func (s *TenantService) GetModelInfo(tenantID string, defaultModel string, model *instanceName = "default" modelName = &defaultChatModelParts[0] } else { - return nil, nil, nil, false, fmt.Errorf("invalid model string: %s", defaultModel) + return nil, nil, nil, "", false, fmt.Errorf("invalid model string: %s", defaultModel) } - if modelType == "ocr" { + // Convert model type tag to standard model type name (matches Python's MODEL_TAG_TO_TYPE) + mappedModelType, ok := MODEL_TAG_TO_TYPE[modelType] + if !ok { + mappedModelType = modelType + } + + if mappedModelType == "ocr" { if *providerName == "infiniflow" && *instanceName == "default" && *modelName == "deepdoc" { - return providerName, instanceName, modelName, true, nil + return providerName, instanceName, modelName, mappedModelType, true, nil } } // Check if the provider and instance exists modelProvider, err := s.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, *providerName) if err != nil { - return nil, nil, nil, false, err + return nil, nil, nil, "", false, err } modelInstance, err := s.modelInstanceDAO.GetByProviderIDAndInstanceName(modelProvider.ID, *instanceName) if err != nil { - return nil, nil, nil, false, err + return nil, nil, nil, "", false, err } modelSchema, err := dao.GetModelProviderManager().GetModelByName(*providerName, *modelName) - if err != nil { - return nil, nil, nil, false, err - } - - if !modelSchema.ModelTypeMap[modelType] { - return nil, nil, nil, false, fmt.Errorf("model %s isn't a chat model", *modelName) + if err == nil && !modelSchema.ModelTypeMap[mappedModelType] { + return nil, nil, nil, "", false, fmt.Errorf("model %s isn't a %s model", *modelName, mappedModelType) } var modelEntity *entity.TenantModel - modelEntity, err = s.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(modelProvider.ID, modelInstance.ID, *modelName) + modelEntity, err = s.modelDAO.GetModelByProviderIDAndInstanceIDAndModelTypeAndModelName(modelProvider.ID, modelInstance.ID, mappedModelType, *modelName) if err != nil { errString := err.Error() if !strings.Contains(errString, "record not found") { - return nil, nil, nil, false, err + return nil, nil, nil, "", false, err } } - enable := modelEntity == nil + // enable = true if: + // 1. modelEntity is nil (no record exists), OR + // 2. modelEntity exists but status is NOT "inactive" + enable := modelEntity == nil || modelEntity.Status != "inactive" - return providerName, instanceName, modelName, enable, nil + return providerName, instanceName, modelName, mappedModelType, enable, nil } @@ -480,68 +498,68 @@ func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, err var result []ModelItem - defaultChatModelProvider, defaultChatModelInstance, defaultChatModelName, defaultChatModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.LLMID, "chat") + defaultChatModelProvider, defaultChatModelInstance, defaultChatModelName, defaultChatModelType, defaultChatModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.LLMID, "chat") if err == nil { result = append(result, ModelItem{ ModelProvider: defaultChatModelProvider, ModelInstance: defaultChatModelInstance, ModelName: defaultChatModelName, - ModelType: "chat", + ModelType: defaultChatModelType, Enable: defaultChatModelEnable, }) } - defaultEmbeddingModelProvider, defaultEmbeddingModelInstance, defaultEmbeddingModelName, defaultEmbeddingModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.EmbDID, "embedding") + defaultEmbeddingModelProvider, defaultEmbeddingModelInstance, defaultEmbeddingModelName, defaultEmbeddingModelType, defaultEmbeddingModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.EmbDID, "embedding") if err == nil { result = append(result, ModelItem{ ModelProvider: defaultEmbeddingModelProvider, ModelInstance: defaultEmbeddingModelInstance, ModelName: defaultEmbeddingModelName, - ModelType: "embedding", + ModelType: defaultEmbeddingModelType, Enable: defaultEmbeddingModelEnable, }) } - defaultRerankModelProvider, defaultRerankModelInstance, defaultRerankModelName, defaultRerankModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.RerankID, "rerank") + defaultRerankModelProvider, defaultRerankModelInstance, defaultRerankModelName, defaultRerankModelType, defaultRerankModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.RerankID, "rerank") if err == nil { result = append(result, ModelItem{ ModelProvider: defaultRerankModelProvider, ModelInstance: defaultRerankModelInstance, ModelName: defaultRerankModelName, - ModelType: "rerank", + ModelType: defaultRerankModelType, Enable: defaultRerankModelEnable, }) } - defaultASRModelProvider, defaultASRModelInstance, defaultASRModelName, defaultASREnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.ASRID, "asr") + defaultASRModelProvider, defaultASRModelInstance, defaultASRModelName, defaultASRModelType, defaultASREnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.ASRID, "asr") if err == nil { result = append(result, ModelItem{ ModelProvider: defaultASRModelProvider, ModelInstance: defaultASRModelInstance, ModelName: defaultASRModelName, - ModelType: "asr", + ModelType: defaultASRModelType, Enable: defaultASREnable, }) } - defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "vision") + defaultImage2TextModelProvider, defaultImage2TextModelInstance, defaultImage2TextModelName, defaultImage2TextModelType, defaultImage2TextModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.Img2TxtID, "vision") if err == nil { result = append(result, ModelItem{ ModelProvider: defaultImage2TextModelProvider, ModelInstance: defaultImage2TextModelInstance, ModelName: defaultImage2TextModelName, - ModelType: "vision", + ModelType: defaultImage2TextModelType, Enable: defaultImage2TextModelEnable, }) } - defaultOCRModelProvider, defaultOCRModelInstance, defaultOCRModelName, defaultOCRModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.OCRID, "ocr") + defaultOCRModelProvider, defaultOCRModelInstance, defaultOCRModelName, defaultOCRModelType, defaultOCRModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, ownedTenant.OCRID, "ocr") if err == nil { result = append(result, ModelItem{ ModelProvider: defaultOCRModelProvider, ModelInstance: defaultOCRModelInstance, ModelName: defaultOCRModelName, - ModelType: "ocr", + ModelType: defaultOCRModelType, Enable: defaultOCRModelEnable, }) } @@ -550,13 +568,13 @@ func (s *TenantService) ListTenantDefaultModels(userID string) ([]ModelItem, err return result, nil } - defaultTTSModelProvider, defaultTTSModelInstance, defaultTTSModelName, defaultTTSModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, *ownedTenant.TTSID, "tts") + defaultTTSModelProvider, defaultTTSModelInstance, defaultTTSModelName, defaultTTSModelType, defaultTTSModelEnable, err := s.GetModelInfo(ownedTenant.TenantID, *ownedTenant.TTSID, "tts") if err == nil { result = append(result, ModelItem{ ModelProvider: defaultTTSModelProvider, ModelInstance: defaultTTSModelInstance, ModelName: defaultTTSModelName, - ModelType: "tts", + ModelType: defaultTTSModelType, Enable: defaultTTSModelEnable, }) } @@ -586,7 +604,7 @@ func (s *TenantService) checkModelAvailable(tenantID, providerName, instanceName } var modelEntity *entity.TenantModel - modelEntity, err = s.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(modelProvider.ID, modelInstance.ID, modelName) + modelEntity, err = s.modelDAO.GetModelByProviderIDAndInstanceIDAndModelTypeAndModelName(modelProvider.ID, modelInstance.ID, modelType, modelName) if err != nil || modelEntity != nil { var errString = err.Error() if errString == "record not found" { @@ -654,3 +672,232 @@ func (s *TenantService) SetTenantDefaultModels(userID, modelProvider, modelInsta return nil } + +// AddedModelItem represents a model in the list of added models +type AddedModelItem struct { + ModelType []string `json:"model_type"` + Name string `json:"name"` + ProviderID string `json:"provider_id"` + ProviderName string `json:"provider_name"` + InstanceID string `json:"instance_id"` + InstanceName string `json:"instance_name"` +} + +// ListTenantAddedModels lists all added models for a tenant +// This implements the Python models_api_service.list_tenant_added_models function +func (s *TenantService) ListTenantAddedModels(tenantID string, modelTypeFilter string) ([]AddedModelItem, error) { + // Step 1: Verify tenant exists + tenant, err := s.tenantDAO.GetByID(tenantID) + if err != nil { + return nil, fmt.Errorf("tenant not found") + } + if tenant == nil { + return nil, fmt.Errorf("tenant not found") + } + + // Step 2: Normalize model type filter (lowercase if provided) + if modelTypeFilter != "" { + modelTypeFilter = strings.ToLower(modelTypeFilter) + } + + // Step 3: Get all providers for tenant + providers, err := s.modelProviderDAO.GetByTenantID(tenantID) + if err != nil { + return nil, err + } + if len(providers) == 0 { + return []AddedModelItem{}, nil + } + + // Step 4: Get all instances for those providers + providerIDs := make([]string, len(providers)) + providerInfoMap := make(map[string]*entity.TenantModelProvider) + for i, p := range providers { + providerIDs[i] = p.ID + providerInfoMap[p.ID] = p + } + + instances, err := s.modelInstanceDAO.GetByProviderIDs(providerIDs) + if err != nil { + return nil, err + } + if len(instances) == 0 { + return []AddedModelItem{}, nil + } + + // Step 5: Build provider_instance_map: map[provider_name][]instance + providerInstanceMap := make(map[string][]*entity.TenantModelInstance) + for _, inst := range instances { + providerName := "" + if p, ok := providerInfoMap[inst.ProviderID]; ok { + providerName = p.ProviderName + } + providerInstanceMap[providerName] = append(providerInstanceMap[providerName], inst) + } + + // Step 6: Get all model records + instanceIDs := make([]string, len(instances)) + instanceInfoMap := make(map[string]*entity.TenantModelInstance) + for i, inst := range instances { + instanceIDs[i] = inst.ID + instanceInfoMap[inst.ID] = inst + } + + modelRecords, err := s.modelDAO.GetModelsByProviderIDsAndInstanceIDs(providerIDs, instanceIDs) + if err != nil { + return nil, err + } + + // Step 7: Filter by model_type if provided and build model_record_map + modelRecordMap := make(map[string][]*entity.TenantModel) + for _, model := range modelRecords { + if modelTypeFilter != "" && model.ModelType != modelTypeFilter { + continue + } + key := fmt.Sprintf("%s_%s_%s", model.ProviderID, model.InstanceID, model.ModelName) + modelRecordMap[key] = append(modelRecordMap[key], model) + } + + // Step 8: Build provider_names list for factory matching + providerNames := make([]string, len(providers)) + for i, p := range providers { + providerNames[i] = p.ProviderName + } + + var addedModels []AddedModelItem + modelKeyInFactory := make(map[string]bool) + + // Step 9: Iterate through factory providers + factories := server.GetModelProviders() + for _, factory := range factories { + // Check if this factory is in our tenant's providers + found := false + for _, pn := range providerNames { + if pn == factory.Name { + found = true + break + } + } + if !found { + continue + } + + factoryInstances, ok := providerInstanceMap[factory.Name] + if !ok || len(factoryInstances) == 0 { + continue + } + + // Step 10: Iterate through each LLM in the factory + for _, llm := range factory.LLMs { + // Apply model type filter + if modelTypeFilter != "" && llm.ModelType != modelTypeFilter { + continue + } + + // Step 11: For each factory instance, check model records + for _, factoryInstance := range factoryInstances { + modelRecordKey := fmt.Sprintf("%s_%s_%s", factoryInstance.ProviderID, factoryInstance.ID, llm.LLMName) + modelKeyInFactory[modelRecordKey] = true + + manualModifiedModels := modelRecordMap[modelRecordKey] + + // Determine active and inactive model types + var activeModelTypes []string + var inactiveModelTypes []string + for _, manualModel := range manualModifiedModels { + if manualModel.Status == "inactive" { + inactiveModelTypes = append(inactiveModelTypes, manualModel.ModelType) + } else { + activeModelTypes = append(activeModelTypes, manualModel.ModelType) + } + } + + // Calculate final model_types: (set([llm["model_type"]] + active_model_types) - set(inactive_model_types)) + modelTypesSet := make(map[string]bool) + modelTypesSet[llm.ModelType] = true + for _, t := range activeModelTypes { + modelTypesSet[t] = true + } + for _, t := range inactiveModelTypes { + delete(modelTypesSet, t) + } + + if len(modelTypesSet) == 0 { + continue + } + + var modelTypes []string + for t := range modelTypesSet { + modelTypes = append(modelTypes, t) + } + + providerName := "" + if p, ok := providerInfoMap[factoryInstance.ProviderID]; ok { + providerName = p.ProviderName + } + + addedModels = append(addedModels, AddedModelItem{ + ModelType: modelTypes, + Name: llm.LLMName, + ProviderID: factoryInstance.ProviderID, + ProviderName: providerName, + InstanceID: factoryInstance.ID, + InstanceName: factoryInstance.InstanceName, + }) + } + } + } + + // Step 12: Handle manual_added_models (models in tenant_model but not in factory) + for modelRecordKey, modelRecords := range modelRecordMap { + if modelKeyInFactory[modelRecordKey] { + continue + } + + if len(modelRecords) == 0 { + continue + } + + // Parse key: provider_id_instance_id_model_name + parts := strings.Split(modelRecordKey, "_") + if len(parts) < 3 { + continue + } + providerID := parts[0] + instanceID := parts[1] + modelName := strings.Join(parts[2:], "_") // model name might contain underscores + + // Get active model types + var modelTypes []string + for _, model := range modelRecords { + if model.Status != "inactive" { + modelTypes = append(modelTypes, model.ModelType) + } + } + + if len(modelTypes) == 0 { + continue + } + + providerName := "" + if p, ok := providerInfoMap[providerID]; ok { + providerName = p.ProviderName + } + + instanceName := "" + if inst, ok := instanceInfoMap[instanceID]; ok { + instanceName = inst.InstanceName + } + + addedModels = append(addedModels, AddedModelItem{ + ModelType: modelTypes, + Name: modelName, + ProviderID: providerID, + ProviderName: providerName, + InstanceID: instanceID, + InstanceName: instanceName, + }) + } + + return addedModels, nil +}