diff --git a/internal/cmd/compare.go b/internal/cmd/compare.go index d8fc9d8..5754a28 100644 --- a/internal/cmd/compare.go +++ b/internal/cmd/compare.go @@ -69,7 +69,7 @@ func compare(provider string, repository string, oldCommit string, newCommit str var schNew schema.PackageSpec if newCommit == "--local" { usr, _ := user.Current() - basePath := fmt.Sprintf("%s/go/src/github.com/pulumi/%s", usr.HomeDir, provider) + basePath := fmt.Sprintf("%s/go/src/github.com/pulumi/pulumi-%s", usr.HomeDir, provider) schemaFile := pkg.StandardSchemaPath(provider) schemaPath := filepath.Join(basePath, schemaFile) var err error @@ -301,15 +301,17 @@ func breakingChanges(oldSchema, newSchema schema.PackageSpec) *diagtree.Node { func compareSchemas(out io.Writer, provider string, oldSchema, newSchema schema.PackageSpec, maxChanges int) { fmt.Fprintf(out, "### Does the PR have any schema changes?\n\n") violations := breakingChanges(oldSchema, newSchema) + concerningTypeStructure(provider, &oldSchema, &newSchema, violations) + displayedViolations := new(bytes.Buffer) lenViolations := violations.Display(displayedViolations, maxChanges) switch lenViolations { case 0: fmt.Fprintln(out, "Looking good! No breaking changes found.") case 1: - fmt.Fprintln(out, "Found 1 breaking change: ") + fmt.Fprintln(out, "Found 1 issue: ") default: - fmt.Fprintf(out, "Found %d breaking changes:\n", lenViolations) + fmt.Fprintf(out, "Found %d issues:\n", lenViolations) } _, err := out.Write(displayedViolations.Bytes()) diff --git a/internal/cmd/types.go b/internal/cmd/types.go new file mode 100644 index 0000000..f223f5d --- /dev/null +++ b/internal/cmd/types.go @@ -0,0 +1,66 @@ +// Copyright 2016-2024, Pulumi Corporation. + +package cmd + +import ( + "strings" + + "github.com/pulumi/pulumi/pkg/v3/codegen/schema" + "github.com/pulumi/schema-tools/internal/util/diagtree" +) + +// Navigate through each resource in the schema and find how many types are referenced from that resource, +// recursively. If the count increased for a resource and is above 200, emit a warning. +func concerningTypeStructure(provider string, oldSchema *schema.PackageSpec, + newSchema *schema.PackageSpec, violations *diagtree.Node) { + oldTypeCounts := typeCountPerResource(oldSchema) + newTypeCounts := typeCountPerResource(newSchema) + section := violations.Label("Max Type Count per Resource") + for resName, newCount := range newTypeCounts { + oldCount, _ := oldTypeCounts[resName] + if newCount > oldCount && newCount > 200 { + msg := section.Value(formatName(provider, resName)) + msg.SetDescription(diagtree.Warn, "number of types increased from %d to %d", oldCount, newCount) + } + } +} + +func typeCountPerResource(schema *schema.PackageSpec) map[string]int { + res := make(map[string]int) + for name, r := range schema.Resources { + visitedTypes := make(map[string]bool) + countTypes(schema, r.InputProperties, visitedTypes) + countTypes(schema, r.Properties, visitedTypes) + res[name] = len(visitedTypes) + } + return res +} + +func countTypes(schema *schema.PackageSpec, props map[string]schema.PropertySpec, visitedTypes map[string]bool) { + for _, prop := range props { + countType(schema, &prop.TypeSpec, visitedTypes) + } +} + +func countType(schema *schema.PackageSpec, ts *schema.TypeSpec, visitedTypes map[string]bool) { + if ts.AdditionalProperties != nil { + countType(schema, ts.AdditionalProperties, visitedTypes) + } + if ts.Items != nil { + countType(schema, ts.Items, visitedTypes) + } + if len(ts.OneOf) > 0 { + for _, t := range ts.OneOf { + countType(schema, &t, visitedTypes) + } + } + if strings.HasPrefix(ts.Ref, "#/types/") { + typeName := strings.TrimPrefix(ts.Ref, "#/types/") + if _, ok := visitedTypes[typeName]; ok { + return + } + visitedTypes[typeName] = true + typ := schema.Types[typeName] + countTypes(schema, typ.Properties, visitedTypes) + } +}