Skip to content

Commit 5ddd00e

Browse files
authored
Clean up augmented schema to contain only types used in the resulting schema (#234)
resolves #233
1 parent fc8f780 commit 5ddd00e

File tree

7 files changed

+246
-1266
lines changed

7 files changed

+246
-1266
lines changed

core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,29 @@ class SchemaBuilder(
119119
.filter { it.name == queryTypeName || it.name == mutationTypeName || it.name == subscriptionTypeName }
120120
.forEach { type -> handler.forEach { h -> h.augmentType(type) } }
121121

122-
// TODO copy over only the types used in the source schema
123-
typeDefinitionRegistry.merge(neo4jTypeDefinitionRegistry)
122+
val types = mutableListOf<Type<*>>()
123+
neo4jTypeDefinitionRegistry.directiveDefinitions.values
124+
.filterNot { typeDefinitionRegistry.getDirectiveDefinition(it.name).isPresent }
125+
.forEach { directiveDefinition ->
126+
typeDefinitionRegistry.add(directiveDefinition)
127+
directiveDefinition.inputValueDefinitions.forEach { types.add(it.type) }
128+
}
129+
typeDefinitionRegistry.types()
130+
.values
131+
.flatMap { typeDefinition ->
132+
when (typeDefinition) {
133+
is ImplementingTypeDefinition -> typeDefinition.fieldDefinitions
134+
.flatMap { fieldDefinition -> fieldDefinition.inputValueDefinitions.map { it.type } + fieldDefinition.type }
135+
is InputObjectTypeDefinition -> typeDefinition.inputValueDefinitions.map { it.type }
136+
else -> emptyList()
137+
}
138+
}
139+
.forEach { types.add(it) }
140+
types
141+
.map { TypeName(it.name()) }
142+
.filterNot { typeDefinitionRegistry.hasType(it) }
143+
.mapNotNull { neo4jTypeDefinitionRegistry.getType(it).unwrap() }
144+
.forEach { typeDefinitionRegistry.add(it) }
124145
}
125146

126147
/**

core/src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class GraphQLSchemaTestSuite(fileName: String) : AsciiDocTestSuite(
5757

5858
diff(expectedSchema, augmentedSchema)
5959
diff(augmentedSchema, expectedSchema)
60+
targetSchemaBlock.adjustedCode = SCHEMA_PRINTER.print(augmentedSchema)
6061
} catch (e: Throwable) {
6162
if (ignore) {
6263
Assumptions.assumeFalse(true, e.message)
@@ -65,9 +66,7 @@ class GraphQLSchemaTestSuite(fileName: String) : AsciiDocTestSuite(
6566
Assertions.fail<Throwable>(e)
6667
}
6768
val actualSchema = SCHEMA_PRINTER.print(augmentedSchema)
68-
targetSchemaBlock.adjustedCode = actualSchema + "\n" +
69-
// this is added since the SCHEMA_PRINTER is not able to print global directives
70-
javaClass.getResource("/lib_directives.graphql").readText()
69+
targetSchemaBlock.adjustedCode = actualSchema
7170
throw AssertionFailedError("augmented schema differs for '$title'",
7271
expectedSchema?.let { SCHEMA_PRINTER.print(it) } ?: targetSchema,
7372
actualSchema,
@@ -84,7 +83,7 @@ class GraphQLSchemaTestSuite(fileName: String) : AsciiDocTestSuite(
8483
private val METHOD_PATTERN = Pattern.compile("(add|delete|update|merge|create)(.*)")
8584

8685
private val SCHEMA_PRINTER = SchemaPrinter(SchemaPrinter.Options.defaultOptions()
87-
.includeDirectives(true)
86+
.includeDirectives(false)
8887
.includeScalarTypes(true)
8988
.includeSchemaDefinition(true)
9089
.includeIntrospectionTypes(false)

0 commit comments

Comments
 (0)