diff --git a/pkg/util/workspace/util.go b/pkg/util/workspace/util.go new file mode 100644 index 0000000..a960255 --- /dev/null +++ b/pkg/util/workspace/util.go @@ -0,0 +1,142 @@ +package workspace + +import ( + "errors" + "fmt" + + v1 "kusionstack.io/kusion-api-go/api.kusion.io/v1" +) + +// GetProjectModuleConfigs returns the module configs of a specified project, whose key is the module name, should be called after ValidateModuleConfigs. +// If got empty module configs, return nil config and nil error. +func GetProjectModuleConfigs(configs v1.ModuleConfigs, projectName string) (map[string]v1.GenericConfig, error) { + if len(configs) == 0 { + return nil, nil + } + if projectName == "" { + return nil, errors.New("empty project name") + } + + projectConfigs := make(map[string]v1.GenericConfig) + for name, cfg := range configs { + moduleConfig, err := getProjectModuleConfig(cfg, projectName) + if moduleConfig == nil { + continue + } + if err != nil { + return nil, fmt.Errorf("%w, module name: %s", err, name) + } + if len(moduleConfig) != 0 { + projectConfigs[name] = moduleConfig + } + } + + return projectConfigs, nil +} + +// GetProjectModuleConfig returns the module config of a specified project, should be called after ValidateModuleConfig. +// If got empty module config, return nil config and nil error. +func GetProjectModuleConfig(config *v1.ModuleConfig, projectName string) (v1.GenericConfig, error) { + if config == nil { + return nil, nil + } + if projectName == "" { + return nil, errors.New("empty project name") + } + + return getProjectModuleConfig(config, projectName) +} + +// getProjectModuleConfig gets the module config of a specified project without checking the correctness of project name. +func getProjectModuleConfig(config *v1.ModuleConfig, projectName string) (v1.GenericConfig, error) { + projectCfg := config.Configs.Default + if len(projectCfg) == 0 { + projectCfg = make(v1.GenericConfig) + } + + for name, cfg := range config.Configs.ModulePatcherConfigs { + if name == v1.DefaultBlock { + continue + } + // check the project is assigned in the block or not. + var contain bool + for _, project := range cfg.ProjectSelector { + if projectName == project { + contain = true + break + } + } + if contain { + for k, v := range cfg.GenericConfig { + if k == v1.ProjectSelectorField { + continue + } + projectCfg[k] = v + } + break + } + } + + return projectCfg, nil +} + +// GetInt32PointerFromGenericConfig returns the value of the key in config which should be of type int. +// If exist but not int, return error. If not exist, return nil. +func GetInt32PointerFromGenericConfig(config v1.GenericConfig, key string) (*int32, error) { + value, ok := config[key] + if !ok { + return nil, nil + } + i, ok := value.(int) + if !ok { + return nil, fmt.Errorf("the value of %s is not int", key) + } + res := int32(i) + return &res, nil +} + +// GetStringFromGenericConfig returns the value of the key in config which should be of type string. +// If exist but not string, return error; If not exist, return "", nil. +func GetStringFromGenericConfig(config v1.GenericConfig, key string) (string, error) { + value, ok := config[key] + if !ok { + return "", nil + } + s, ok := value.(string) + if !ok { + return "", fmt.Errorf("the value of %s is not string", key) + } + return s, nil +} + +// GetMapFromGenericConfig returns the value of the key in config which should be of type map[string]any. +// If exist but not map[string]any, return error; If not exist, return nil, nil. +func GetMapFromGenericConfig(config v1.GenericConfig, key string) (map[string]any, error) { + value, ok := config[key] + if !ok { + return nil, nil + } + m, ok := value.(v1.GenericConfig) + if !ok { + return nil, fmt.Errorf("the value of %s is not map", key) + } + return m, nil +} + +// GetStringMapFromGenericConfig returns the value of the key in config which should be of type map[string]string. +// If exist but not map[string]string, return error; If not exist, return nil, nil. +func GetStringMapFromGenericConfig(config v1.GenericConfig, key string) (map[string]string, error) { + m, err := GetMapFromGenericConfig(config, key) + if err != nil { + return nil, err + } + stringMap := make(map[string]string) + for k, v := range m { + stringValue, ok := v.(string) + if !ok { + return nil, fmt.Errorf("the value of %s.%s is not string", key, k) + } + stringMap[k] = stringValue + } + return stringMap, nil +} diff --git a/pkg/util/workspace/util_test.go b/pkg/util/workspace/util_test.go new file mode 100644 index 0000000..99a139b --- /dev/null +++ b/pkg/util/workspace/util_test.go @@ -0,0 +1,285 @@ +package workspace + +import ( + "testing" + + "github.com/stretchr/testify/assert" + v1 "kusionstack.io/kusion-api-go/api.kusion.io/v1" +) + +func mockValidModuleConfigs() map[string]*v1.ModuleConfig { + return map[string]*v1.ModuleConfig{ + "mysql": { + Path: "ghcr.io/kusionstack/mysql", + Version: "0.1.0", + Configs: v1.Configs{ + Default: v1.GenericConfig{ + "type": "aws", + "version": "5.7", + "instanceType": "db.t3.micro", + }, + ModulePatcherConfigs: v1.ModulePatcherConfigs{ + "smallClass": { + GenericConfig: v1.GenericConfig{ + "instanceType": "db.t3.small", + }, + ProjectSelector: []string{"foo", "bar"}, + }, + }, + }, + }, + "network": { + Path: "ghcr.io/kusionstack/network", + Version: "0.1.0", + Configs: v1.Configs{ + Default: v1.GenericConfig{ + "type": "aws", + }, + }, + }, + } +} + +func mockGenericConfig() v1.GenericConfig { + return v1.GenericConfig{ + "int_type_field": 2, + "string_type_field": "kusion", + "map_type_field": v1.GenericConfig{ + "k1": "v1", + "k2": 2, + }, + "string_map_type_field": v1.GenericConfig{ + "k1": "v1", + "k2": "v2", + }, + } +} + +func Test_GetProjectModuleConfigs(t *testing.T) { + testcases := []struct { + name string + projectName string + moduleConfigs v1.ModuleConfigs + success bool + expectedProjectConfigs map[string]v1.GenericConfig + }{ + { + name: "successfully get project module configs", + projectName: "foo", + moduleConfigs: mockValidModuleConfigs(), + success: true, + expectedProjectConfigs: map[string]v1.GenericConfig{ + "mysql": { + "type": "aws", + "version": "5.7", + "instanceType": "db.t3.small", + }, + "network": { + "type": "aws", + }, + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + cfg, err := GetProjectModuleConfigs(tc.moduleConfigs, tc.projectName) + assert.Equal(t, tc.success, err == nil) + assert.Equal(t, tc.expectedProjectConfigs, cfg) + }) + } +} + +func Test_GetProjectModuleConfig(t *testing.T) { + testcases := []struct { + name string + success bool + projectName string + moduleConfig *v1.ModuleConfig + expectedProjectConfig v1.GenericConfig + }{ + { + name: "successfully get default project module config", + projectName: "baz", + moduleConfig: mockValidModuleConfigs()["mysql"], + success: true, + expectedProjectConfig: v1.GenericConfig{ + "type": "aws", + "version": "5.7", + "instanceType": "db.t3.micro", + }, + }, + { + name: "successfully get override project module config", + projectName: "foo", + moduleConfig: mockValidModuleConfigs()["mysql"], + success: true, + expectedProjectConfig: v1.GenericConfig{ + "type": "aws", + "version": "5.7", + "instanceType": "db.t3.small", + }, + }, + { + name: "failed to get config empty project name", + projectName: "", + moduleConfig: mockValidModuleConfigs()["mysql"], + success: false, + expectedProjectConfig: nil, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + cfg, err := GetProjectModuleConfig(tc.moduleConfig, tc.projectName) + assert.Equal(t, tc.success, err == nil) + assert.Equal(t, tc.expectedProjectConfig, cfg) + }) + } +} + +func Test_GetIntFieldFromGenericConfig(t *testing.T) { + r2 := int32(2) + + testcases := []struct { + name string + key string + success bool + expectedValue *int32 + }{ + { + name: "successfully get int type field", + key: "int_type_field", + success: true, + expectedValue: &r2, + }, + { + name: "get not exist field", + key: "not_exist", + success: true, + expectedValue: nil, + }, + { + name: "get field failed not int type", + key: "string_type_field", + success: false, + expectedValue: nil, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + value, err := GetInt32PointerFromGenericConfig(mockGenericConfig(), tc.key) + assert.Equal(t, tc.success, err == nil) + assert.Equal(t, tc.expectedValue, value) + }) + } +} + +func Test_GetStringFieldFromGenericConfig(t *testing.T) { + testcases := []struct { + name string + key string + success bool + expectedValue string + }{ + { + name: "successfully get string type field", + key: "string_type_field", + success: true, + expectedValue: "kusion", + }, + { + name: "get not exist field", + key: "not_exist", + success: true, + expectedValue: "", + }, + { + name: "get field failed not string type", + key: "int_type_field", + success: false, + expectedValue: "", + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + value, err := GetStringFromGenericConfig(mockGenericConfig(), tc.key) + assert.Equal(t, tc.success, err == nil) + assert.Equal(t, tc.expectedValue, value) + }) + } +} + +func Test_GetMapFieldFromGenericConfig(t *testing.T) { + testcases := []struct { + name string + key string + success bool + expectedValue map[string]any + }{ + { + name: "successfully get map type field", + key: "map_type_field", + success: true, + expectedValue: map[string]any{ + "k1": "v1", + "k2": 2, + }, + }, + { + name: "get not exist field", + key: "not_exist", + success: true, + expectedValue: nil, + }, + { + name: "get field failed not map type", + key: "int_type_field", + success: false, + expectedValue: nil, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + value, err := GetMapFromGenericConfig(mockGenericConfig(), tc.key) + assert.Equal(t, tc.success, err == nil) + assert.Equal(t, tc.expectedValue, value) + }) + } +} + +func Test_GetStringMapFieldFromGenericConfig(t *testing.T) { + testcases := []struct { + name string + key string + success bool + expectedValue map[string]string + }{ + { + name: "successfully get string map type field", + key: "string_map_type_field", + success: true, + expectedValue: map[string]string{ + "k1": "v1", + "k2": "v2", + }, + }, + { + name: "get field failed map key not string", + key: "map_type_field", + success: false, + expectedValue: nil, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + value, err := GetStringMapFromGenericConfig(mockGenericConfig(), tc.key) + assert.Equal(t, tc.success, err == nil) + assert.Equal(t, tc.expectedValue, value) + }) + } +}