Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions internal/dao/tenant_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ func (dao *TenantModelDAO) GetModelByProviderIDAndInstanceIDAndModelName(provide
return &model, nil
}

// GetModelByProviderIDAndInstanceIDAndModelTypeAndModelName gets a model by provider ID, instance ID, model type and model name
func (dao *TenantModelDAO) GetModelByProviderIDAndInstanceIDAndModelTypeAndModelName(providerID, instanceID, modelType, modelName string) (*entity.TenantModel, error) {
var model entity.TenantModel
err := DB.Where("provider_id = ? AND instance_id = ? AND model_type = ? AND model_name = ?", providerID, instanceID, modelType, modelName).First(&model).Error
if err != nil {
return nil, err
}
return &model, nil
}

// GetModelsByInstanceID get all models by instance ID
func (dao *TenantModelDAO) GetModelsByInstanceID(instanceID string) ([]*entity.TenantModel, error) {
var models []*entity.TenantModel
Expand All @@ -75,3 +85,13 @@ func (dao *TenantModelDAO) GetModelsByInstanceID(instanceID string) ([]*entity.T
}
return models, nil
}

// GetModelsByProviderIDsAndInstanceIDs get all models by provider IDs and instance IDs
func (dao *TenantModelDAO) GetModelsByProviderIDsAndInstanceIDs(providerIDs, instanceIDs []string) ([]*entity.TenantModel, error) {
var models []*entity.TenantModel
err := DB.Where("provider_id IN ? AND instance_id IN ?", providerIDs, instanceIDs).Find(&models).Error
if err != nil {
return nil, err
}
return models, nil
}
10 changes: 10 additions & 0 deletions internal/dao/tenant_model_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,13 @@ func (dao *TenantModelInstanceDAO) DeleteByProviderIDAndInstanceName(providerID,
result := DB.Unscoped().Where("provider_id = ? and instance_name = ?", providerID, instanceName).Delete(&entity.TenantModelInstance{})
return result.RowsAffected, result.Error
}

// GetByProviderIDs get all instances by provider IDs
func (dao *TenantModelInstanceDAO) GetByProviderIDs(providerIDs []string) ([]*entity.TenantModelInstance, error) {
var instances []*entity.TenantModelInstance
err := DB.Where("provider_id IN ?", providerIDs).Find(&instances).Error
if err != nil {
return nil, err
}
return instances, nil
}
10 changes: 10 additions & 0 deletions internal/dao/tenant_model_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,13 @@ func (dao *TenantModelProviderDAO) ListByID(id string) ([]string, error) {
Pluck("provider_name", &providerNames).Error
return providerNames, err
}

// GetByTenantID get all model providers by tenant ID
func (dao *TenantModelProviderDAO) GetByTenantID(tenantID string) ([]*entity.TenantModelProvider, error) {
var providers []*entity.TenantModelProvider
err := DB.Where("tenant_id = ?", tenantID).Find(&providers).Error
if err != nil {
return nil, err
}
return providers, nil
}
25 changes: 22 additions & 3 deletions internal/entity/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ func (pm *ProviderManager) ListProviders() ([]map[string]interface{}, error) {

modelTypeSet := make(map[string]struct{})
for _, model := range provider.Models {

for _, modelType := range model.ModelTypes {
modelTypeSet[modelType] = struct{}{}
}
Expand All @@ -287,6 +288,9 @@ func (pm *ProviderManager) ListProviders() ([]map[string]interface{}, error) {
"model_types": modelTypes,
"url_suffix": provider.URLSuffix,
}
if (len(modelTypes) == 0) {
continue
}
providers = append(providers, providerData)
}

Expand All @@ -305,10 +309,25 @@ func (pm *ProviderManager) GetProviderByName(providerName string) (map[string]in
return nil, fmt.Errorf("provider '%s' not found", providerName)
}

modelTypeSet := make(map[string]struct{})
for _, model := range provider.Models {
if len(model.ModelTypes) == 0 {
continue
}
for _, modelType := range model.ModelTypes {
modelTypeSet[modelType] = struct{}{}
}
}

var modelTypes []string
for modelType := range modelTypeSet {
modelTypes = append(modelTypes, modelType)
}

providerInfo := map[string]interface{}{
"name": provider.Name,
"base_url": provider.URL,
"total_models": len(provider.Models),
"name": provider.Name,
"url": provider.URL,
"model_types": modelTypes,
}

return providerInfo, nil
Expand Down
59 changes: 58 additions & 1 deletion internal/handler/tenant.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (h *TenantHandler) GetModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": "success",
"data": defaultModels,
"data": gin.H{"models": defaultModels},
})
}

Expand Down Expand Up @@ -119,6 +119,63 @@ func (h *TenantHandler) SetModels(c *gin.Context) {
})
}

// GetAddedModels lists all added models for the current user's tenant
// @Summary List Added Models
// @Description List all models added to the current user's tenant
// @Tags models
// @Accept json
// @Produce json
// @Security ApiKeyAuth
// @Param type query string false "Model type filter (chat, embedding, rerank, asr, vision, tts, ocr)"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/models [get]
func (h *TenantHandler) GetAddedModels(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}

// Get tenant ID for the user
tenantInfos, err := h.tenantService.GetTenantInfo(user.ID)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
"message": err.Error(),
"data": nil,
})
return
}

if tenantInfos == nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeDataError,
"message": "Tenant not found",
"data": nil,
})
return
}

// Get optional model type filter from query params
modelTypeFilter := c.Query("type")

addedModels, err := h.tenantService.ListTenantAddedModels(tenantInfos.TenantID, modelTypeFilter)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": common.CodeExceptionError,
"message": err.Error(),
"data": nil,
})
return
}

c.JSON(http.StatusOK, gin.H{
"code": common.CodeSuccess,
"message": "success",
"data": addedModels,
})
}

// TenantInfo get tenant information
// @Summary Get Tenant Information
// @Description Get current user's tenant information (owner tenant)
Expand Down
7 changes: 4 additions & 3 deletions internal/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func (r *Router) Setup(engine *gin.Engine) {
// provider pool route group
provider := v1.Group("/providers")
{
provider.GET("/", r.providerHandler.ListProviders)
provider.GET("", r.providerHandler.ListProviders)
provider.PUT("/", r.providerHandler.AddProvider)
provider.GET("/:provider_name", r.providerHandler.ShowProvider)
provider.DELETE("/:provider_name", r.providerHandler.DeleteProvider)
Expand Down Expand Up @@ -301,8 +301,9 @@ func (r *Router) Setup(engine *gin.Engine) {

model := v1.Group("/models")
{
model.GET("/", r.tenantHandler.GetModels)
model.PATCH("/", r.tenantHandler.SetModels)
model.GET("/default", r.tenantHandler.GetModels)
model.PATCH("/default", r.tenantHandler.SetModels)
model.GET("", r.tenantHandler.GetAddedModels) // GET /api/v1/models - list tenant added models
}

connector := v1.Group("/connectors")
Expand Down
28 changes: 15 additions & 13 deletions internal/service/model_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,16 +306,16 @@ func (m *ModelProviderService) ListProviderInstances(providerName, userID string
var extra map[string]string
err = json.Unmarshal([]byte(instance.Extra), &extra)
if err != nil {
return nil, common.CodeServerError, err
extra = make(map[string]string)
}

result = append(result, map[string]interface{}{
"id": instance.ID,
"instanceName": instance.InstanceName,
"providerID": instance.ProviderID,
"apiKey": instance.APIKey,
"status": instance.Status,
"extra": instance.Extra,
"api_key": instance.APIKey,
"id": instance.ID,
"instance_name": instance.InstanceName,
"provider_id": instance.ProviderID,
"region": extra["region"],
"status": instance.Status,
})
}

Expand Down Expand Up @@ -351,15 +351,15 @@ func (m *ModelProviderService) ShowProviderInstance(providerName, instanceName,
var extra map[string]string
err = json.Unmarshal([]byte(instance.Extra), &extra)
if err != nil {
return nil, common.CodeServerError, err
extra = make(map[string]string)
}

result := map[string]interface{}{
"id": instance.ID,
"instanceName": instance.InstanceName,
"providerID": instance.ProviderID,
"status": instance.Status,
"region": extra["region"],
"id": instance.ID,
"instance_name": instance.InstanceName,
"provider_id": instance.ProviderID,
"status": instance.Status,
"region": extra["region"],
}

return result, common.CodeSuccess, nil
Expand Down Expand Up @@ -744,6 +744,8 @@ func (m *ModelProviderService) ListInstanceModels(providerName, instanceName, us
for _, model := range allModels {
// convert model["name"] to string
modelName := model["name"].(string)
model["model_type"] = model["model_types"]
delete(model, "model_types")
if modelNames[modelName] {
model["status"] = "inactive"
} else {
Expand Down
Loading