Skip to content

Commit

Permalink
Add Constant.referencedConstants (#499)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamin-bader authored Dec 6, 2022
1 parent e685215 commit aa82c00
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class Constant private constructor (
override val isDeprecated: Boolean
get() = mixin.isDeprecated

var referencedConstants: List<Constant> = emptyList()
private set

internal constructor(element: ConstElement, namespaces: Map<NamespaceScope, String>, type: ThriftType? = null)
: this(element, UserElementMixin(element, namespaces), type)

Expand All @@ -64,8 +67,46 @@ class Constant private constructor (
type_ = linker.resolveType(element.type)
}

internal fun linkReferencedConstants(linker: Linker) {
referencedConstants = type.accept(ConstantReferenceVisitor(value, linker))
}

internal fun validate(linker: Linker) {
validate(linker, element.value, type)
detectCycles(linker, mutableMapOf(), mutableListOf(this))
}

private fun detectCycles(linker: Linker, visitStates: MutableMap<Constant, VisitState>, path: MutableList<Constant>) {
if (visitStates[this] == VisitState.VISITING) {
val message = path.joinToString(
separator = "\n\t -> ",
prefix = "Cycle detected while validating Thrift constants: \n\t") { elem ->
"${elem.name} (${elem.location.path}:${elem.location.line})"
}
throw IllegalStateException(message)
}

visitStates[this] = VisitState.VISITING

for (const in referencedConstants) {
if (visitStates[const] == VisitState.VISITED) {
continue
}

path.add(const)
const.detectCycles(linker, visitStates, path)
path.removeLast()
}

visitStates[this] = VisitState.VISITED
}

/**
* Used to implement a depth-first search for cycle detection during validation.
*/
private enum class VisitState {
VISITING,
VISITED
}

/**
Expand Down Expand Up @@ -418,6 +459,116 @@ class Constant private constructor (
}
}

private class ConstantReferenceVisitor(
private val cve: ConstValueElement,
private val linker: Linker,
) : ThriftType.Visitor<List<Constant>> {
override fun visitVoid(voidType: BuiltinType): List<Constant> = emptyList()

private fun getScalarConstantReference(): List<Constant> {
if (cve !is IdentifierValueElement) {
return emptyList()
}

val ref = linker.lookupConst(cve.value)
?: throw IllegalStateException("Unrecognized const identifier: ${cve.value}")

return listOf(ref)
}

override fun visitBool(boolType: BuiltinType): List<Constant> {
if (cve is IdentifierValueElement) {
val maybeRef = linker.lookupConst(cve.value)
if (maybeRef != null) {
return listOf(maybeRef)
}
// Bool constants can have IdentifierValueElement values that are not
// const references; that's likely the case here.
}
return emptyList()
}
override fun visitByte(byteType: BuiltinType) = getScalarConstantReference()
override fun visitI16(i16Type: BuiltinType) = getScalarConstantReference()
override fun visitI32(i32Type: BuiltinType) = getScalarConstantReference()
override fun visitI64(i64Type: BuiltinType) = getScalarConstantReference()
override fun visitDouble(doubleType: BuiltinType) = getScalarConstantReference()
override fun visitString(stringType: BuiltinType) = getScalarConstantReference()
override fun visitBinary(binaryType: BuiltinType) = getScalarConstantReference()

override fun visitEnum(enumType: EnumType): List<Constant> {
if (cve is IdentifierValueElement) {
val maybeRef = linker.lookupConst(cve.value)
if (maybeRef != null) {
return listOf(maybeRef)
}
// Enum constants can have IdentifierValueElement values that are not
// const references; that's likely the case here.
}
return emptyList()
}

override fun visitList(listType: ListType) = visitListOrSet(listType.elementType)

override fun visitSet(setType: SetType) = visitListOrSet(setType.elementType)

private fun visitListOrSet(elementType: ThriftType): List<Constant> {
return when (cve) {
is IdentifierValueElement -> getScalarConstantReference()

is ListValueElement -> cve.value.flatMap { elem ->
val visitor = ConstantReferenceVisitor(elem, linker)
elementType.accept(visitor)
}

else -> error("wat")
}
}

override fun visitMap(mapType: MapType): List<Constant> {
return when (cve) {
is IdentifierValueElement -> getScalarConstantReference()

is MapValueElement -> cve.value.values.flatMap { elem ->
val visitor = ConstantReferenceVisitor(elem, linker)
mapType.valueType.accept(visitor)
}

else -> error("no")
}
}

override fun visitStruct(structType: StructType): List<Constant> {
if (cve is IdentifierValueElement) {
return getScalarConstantReference()
}

if (cve !is MapValueElement) {
error("unpossible")
}

val fieldsByName = structType.fields.associateBy { it.name }

return cve.value.flatMap { (key, value) ->
if (key !is LiteralValueElement) {
error("wtf")
}

val fieldName = key.value
val field = fieldsByName[fieldName] ?: error("nope")
val visitor = ConstantReferenceVisitor(value, linker)
field.type.accept(visitor)
}
}

override fun visitTypedef(typedefType: TypedefType): List<Constant> {
return typedefType.trueType.accept(this)
}

override fun visitService(serviceType: ServiceType): List<Constant> {
error("No such thing as a Service-typed constant")
}
}

companion object {
@VisibleForTesting
internal fun validate(symbolTable: SymbolTable, value: ConstValueElement, expected: ThriftType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ internal class Linker(

// Only validate the schema if linking succeeded; no point otherwise.
if (!reporter.hasError) {
linkConstantReferences()

validateTypedefs()
validateConstants()
validateStructs()
Expand Down Expand Up @@ -268,6 +270,16 @@ internal class Linker(
}
}

private fun linkConstantReferences() {
for (constant in program.constants) {
try {
constant.linkReferencedConstants(this)
} catch (e: IllegalStateException) {
reporter.error(constant.location, e.message ?: "Error linking constant references")
}
}
}

private fun validateStructs() {
for (struct in program.structs) {
struct.validate(this)
Expand Down Expand Up @@ -433,7 +445,8 @@ internal class Linker(
val qualifiedName = symbol.substring(ix + 1)
val expectedPath = "$includeName.thrift"
constant = program.includes
.filter { p -> p.location.path == expectedPath }
.asSequence()
.filter { p -> p.location.path == expectedPath } // TODO: Should this be ==, or endsWith?
.mapNotNull { p -> p.constantMap[qualifiedName] }
.firstOrNull()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,61 @@ class LoaderTest {
load(thrift)
}

@Test
fun constantsWithCircularReferences() {
val thrift = """
struct Node {
1: required string key;
2: optional Node value;
}
const Node A = {
"key": "foo",
"value": B
}
const Node B = {
"key": "bar",
"value": A
}
""".trimIndent()

val e = shouldThrow<LoadFailedException> { load(thrift) }
e.message shouldContain "Cycle detected while validating Thrift constants"
}

@Test
fun constantCycleWithMutuallyDependentStructs() {
val thrift = """
struct TweedleDee {
1: TweedleDum brother;
}
struct TweedleDum {
1: TweedleDee brother;
}
const TweedleDee TWEEDLE_DEE = { "brother": TWEEDLE_DUM }
const TweedleDum TWEEDLE_DUM = { "brother": TWEEDLE_DEE }
""".trimIndent()

val e = shouldThrow<LoadFailedException> { load(thrift) }
e.message shouldContain "Cycle detected while validating Thrift constants"
}

@Test
fun populatesReferencedConstants() {
val thrift = """
const string A = "a";
const string B = "b";
const list<string> STRS = [A, B];
""".trimIndent()

val schema = load(thrift)
val strs = schema.constants.last()
strs.referencedConstants.map { it.name} shouldBe listOf("A", "B")
}

@Test
fun unionWithOneDefaultValue() {
val thrift = """
Expand Down

0 comments on commit aa82c00

Please sign in to comment.