Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

alias-type on record and adt types #2197

Merged
merged 3 commits into from
Mar 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/ast/analysis/typesystem/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ std::set<TypeAttribute> TypeAnalysis::getTypeAttributes(const Argument* arg) con
TypeAttribute::Record};
}
for (const auto& type : types) {
typeAttributes.insert(getTypeAttribute(type));
typeAttributes.insert(getTypeAttribute(skipAliasesType(type)));
}
return typeAttributes;
}
Expand Down Expand Up @@ -530,7 +530,7 @@ void TypeAnnotationPrinter::branchOnArgument(const Argument* cur, const Type& ty
} else if (isA<NilConstant>(*cur)) {
print_(type_identity<NilConstant>(), *as<NilConstant>(cur));
} else if (isA<RecordInit>(*cur)) {
print_(type_identity<RecordInit>(), *as<RecordInit>(cur), *as<RecordType>(type));
print_(type_identity<RecordInit>(), *as<RecordInit>(cur), *as<RecordType>(getBaseType(&type)));
} else if (isA<BranchInit>(*cur)) {
print_(type_identity<BranchInit>(), *as<BranchInit>(cur));
} else if (isA<IntrinsicFunctor>(*cur)) {
Expand Down
4 changes: 2 additions & 2 deletions src/ast/analysis/typesystem/TypeConstrainsAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ namespace souffle::ast::analysis {

void TypeConstraintsAnalysis::visitSink(const Atom& atom) {
iterateOverAtom(atom, [&](const Argument& argument, const Type& attributeType) {
if (isA<RecordType>(attributeType)) {
addConstraint(isSubtypeOf(getVar(argument), getBaseType(&attributeType)));
if (isA<RecordType>(skipAliasesType(attributeType))) {
addConstraint(isSubtypeOf(getVar(argument), attributeType));
return;
}
for (auto& constantType : typeEnv.getConstantTypes()) {
Expand Down
13 changes: 10 additions & 3 deletions src/ast/analysis/typesystem/TypeConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ TypeConstraint isSubtypeOf(const TypeVar& variable, const Type& type) {
TypeSet& assignment = assignments[variable];

if (assignment.isAll()) {
assignment = TypeSet(type);
assignment = TypeSet(skipAliasesType(type));
return true;
}

TypeSet newAssignment;
for (const Type& t : assignment) {
assert(!isA<AliasType>(t));
newAssignment.insert(getGreatestCommonSubtypes(t, type));
}

Expand Down Expand Up @@ -90,6 +91,7 @@ TypeConstraint hasSuperTypeInSet(const TypeVar& var, TypeSet values) {

TypeSet newAssigments;
for (const Type& type : assigments) {
assert(!isA<AliasType>(type));
bool existsSuperTypeInValues =
any_of(values, [&type](const Type& value) { return isSubtypeOf(type, value); });
if (existsSuperTypeInValues) {
Expand Down Expand Up @@ -138,6 +140,7 @@ TypeConstraint subtypesOfTheSameBaseType(const TypeVar& left, const TypeVar& rig
// Left
if (!assigmentsLeft.isAll()) {
for (const Type& type : assigmentsLeft) {
assert(!isA<AliasType>(type));
if (isA<SubsetType>(type) || isA<ConstantType>(type)) {
baseTypesLeft.insert(getBaseType(&type));
}
Expand All @@ -146,6 +149,7 @@ TypeConstraint subtypesOfTheSameBaseType(const TypeVar& left, const TypeVar& rig
// Right
if (!assigmentsRight.isAll()) {
for (const Type& type : assigmentsRight) {
assert(!isA<AliasType>(type));
if (isA<SubsetType>(type) || isA<ConstantType>(type)) {
baseTypesRight.insert(getBaseType(&type));
}
Expand Down Expand Up @@ -174,6 +178,7 @@ TypeConstraint subtypesOfTheSameBaseType(const TypeVar& left, const TypeVar& rig

// Allow types if they are subtypes of any of the common base types.
for (const Type& type : assigmentsLeft) {
assert(!isA<AliasType>(type));
bool isSubtypeOfCommonBaseType = any_of(baseTypes.begin(), baseTypes.end(),
[&type](const Type& baseType) { return isSubtypeOf(type, baseType); });
if (isSubtypeOfCommonBaseType) {
Expand All @@ -182,6 +187,7 @@ TypeConstraint subtypesOfTheSameBaseType(const TypeVar& left, const TypeVar& rig
}

for (const Type& type : assigmentsRight) {
assert(!isA<AliasType>(type));
bool isSubtypeOfCommonBaseType = any_of(baseTypes.begin(), baseTypes.end(),
[&type](const Type& baseType) { return isSubtypeOf(type, baseType); });
if (isSubtypeOfCommonBaseType) {
Expand Down Expand Up @@ -330,12 +336,13 @@ TypeConstraint isSubtypeOfComponent(
TypeSet newRecordTypes;

for (const Type& type : recordTypes) {
assert(!isA<AliasType>(type));
// A type must be either a record type or a subset of a record type
if (!isOfKind(type, TypeAttribute::Record)) {
if (!isBaseOfKind(type, TypeAttribute::Record)) {
continue;
}

const auto& typeAsRecord = *as<RecordType>(type);
const auto& typeAsRecord = *as<RecordType>(getBaseType(&type));

// Wrong size => skip.
if (typeAsRecord.getFields().size() <= index) {
Expand Down
75 changes: 61 additions & 14 deletions src/ast/analysis/typesystem/TypeSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,26 @@ bool isOfRootType(const Type& type, const Type& root) {
}

bool isOfKind(const Type& type, TypeAttribute kind) {
auto& t = skipAliasesType(type);

if (kind == TypeAttribute::Record) {
return isA<RecordType>(t);
} else if (kind == TypeAttribute::ADT) {
return isA<AlgebraicDataType>(t);
}

return isOfRootType(t, t.getTypeEnvironment().getConstantType(kind));
}

bool isBaseOfKind(const Type& type, TypeAttribute kind) {
if (auto subset = as<SubsetType>(type)) {
return isBaseOfKind(subset->getBaseType(), kind);
}

if (auto alias = as<AliasType>(type)) {
return isBaseOfKind(alias->getAliasType(), kind);
}

if (kind == TypeAttribute::Record) {
return isA<RecordType>(type);
} else if (kind == TypeAttribute::ADT) {
Expand All @@ -207,17 +227,17 @@ bool isOfKind(const TypeSet& typeSet, TypeAttribute kind) {

std::string getTypeQualifier(const Type& type) {
std::string kind = [&]() {
if (isOfKind(type, TypeAttribute::Signed)) {
if (isBaseOfKind(type, TypeAttribute::Signed)) {
return "i";
} else if (isOfKind(type, TypeAttribute::Unsigned)) {
} else if (isBaseOfKind(type, TypeAttribute::Unsigned)) {
return "u";
} else if (isOfKind(type, TypeAttribute::Float)) {
} else if (isBaseOfKind(type, TypeAttribute::Float)) {
return "f";
} else if (isOfKind(type, TypeAttribute::Symbol)) {
} else if (isBaseOfKind(type, TypeAttribute::Symbol)) {
return "s";
} else if (isOfKind(type, TypeAttribute::Record)) {
} else if (isBaseOfKind(type, TypeAttribute::Record)) {
return "r";
} else if (isOfKind(type, TypeAttribute::ADT)) {
} else if (isBaseOfKind(type, TypeAttribute::ADT)) {
return "+";
} else {
fatal("Unsupported kind");
Expand All @@ -227,10 +247,13 @@ std::string getTypeQualifier(const Type& type) {
return tfm::format("%s:%s", kind, type.getName());
}

bool isSubtypeOf(const Type& a, const Type& b) {
assert(&a.getTypeEnvironment() == &b.getTypeEnvironment() &&
bool isSubtypeOf(const Type& ta, const Type& tb) {
assert(&ta.getTypeEnvironment() == &tb.getTypeEnvironment() &&
"Types must be in the same type environment");

auto& a = skipAliasesType(ta);
auto& b = skipAliasesType(tb);

if (isOfRootType(a, b)) {
return true;
}
Expand Down Expand Up @@ -263,10 +286,13 @@ void TypeEnvironment::print(std::ostream& out) const {
}
}

TypeSet getGreatestCommonSubtypes(const Type& a, const Type& b) {
assert(&a.getTypeEnvironment() == &b.getTypeEnvironment() &&
TypeSet getGreatestCommonSubtypes(const Type& ta, const Type& tb) {
assert(&ta.getTypeEnvironment() == &tb.getTypeEnvironment() &&
"Types must be in the same type environment");

auto& a = skipAliasesType(ta);
auto& b = skipAliasesType(tb);

if (isSubtypeOf(a, b)) {
return TypeSet(a);
}
Expand Down Expand Up @@ -365,7 +391,7 @@ bool haveCommonSupertype(const Type& a, const Type& b) {
TypeAttribute getTypeAttribute(const Type& type) {
for (auto typeAttribute : {TypeAttribute::Signed, TypeAttribute::Unsigned, TypeAttribute::Float,
TypeAttribute::Record, TypeAttribute::Symbol, TypeAttribute::ADT}) {
if (isOfKind(type, typeAttribute)) {
if (isOfKind(skipAliasesType(type), typeAttribute)) {
return typeAttribute;
}
}
Expand All @@ -391,12 +417,33 @@ bool isADTEnum(const AlgebraicDataType& type) {
}

const Type& getBaseType(const Type* type) {
while (auto subset = as<SubsetType>(type)) {
type = &subset->getBaseType();
};
if (auto subset = as<SubsetType>(type)) {
return getBaseType(&subset->getBaseType());
}

if (auto alias = as<AliasType>(type)) {
return getBaseType(&alias->getAliasType());
}

assert((isA<ConstantType>(type) || isA<RecordType>(type)) &&
"Root must be a constant type or a record type");
return *type;
}

const Type& skipAliasesType(const Type* type) {
if (auto alias = as<AliasType>(type)) {
return skipAliasesType(&alias->getAliasType());
}

return *type;
}

const Type& skipAliasesType(const Type& type) {
if (auto alias = as<AliasType>(type)) {
return skipAliasesType(alias->getAliasType());
}

return type;
}

} // namespace souffle::ast::analysis
14 changes: 13 additions & 1 deletion src/ast/analysis/typesystem/TypeSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -581,10 +581,15 @@ bool isSubtypeOf(const Type& a, const Type& b);
/** Returns fully qualified name for a given type */
std::string getTypeQualifier(const Type& type);

/** Check if the type is of a kind corresponding to the TypeAttribute */
/** Check if the type is of a kind corresponding to the TypeAttribute (does not traverse sub-type and
* alias-type if kind is ADT or Record) */
bool isOfKind(const Type& type, TypeAttribute kind);
bool isOfKind(const TypeSet& typeSet, TypeAttribute kind);

/** Check if the type is a direct or a sub-type or an alias-type of a kind corresponding to the TypeAttribute
*/
bool isBaseOfKind(const Type& type, TypeAttribute kind);

/** Get type attributes */
TypeAttribute getTypeAttribute(const Type&);

Expand Down Expand Up @@ -637,8 +642,15 @@ TypeSet getGreatestCommonSubtypes(const Types&... types) {
*/
bool haveCommonSupertype(const Type& a, const Type& b);

/** Return the base type of type, skipping aliases and ascending sub-types. */
const Type& getBaseType(const Type* type);

/** Return the un-aliased type of type. */
const Type& skipAliasesType(const Type* type);

/** Return the un-aliased type of type. */
const Type& skipAliasesType(const Type& type);

/**
* Determine if two types are equivalent.
* That is, check if a <: b and b <: a
Expand Down
60 changes: 38 additions & 22 deletions src/ast/transform/IOAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,28 +183,33 @@ class IOAttributesTransformer : public Transformer {

std::map<std::string, json11::Json> sumTypes;

visit(program.getTypes(), [&](const AlgebraicDataType& astAlgebraicDataType) {
auto& sumType = asAssert<analysis::AlgebraicDataType>(typeEnv.getType(astAlgebraicDataType));
auto& branches = sumType.getBranches();
for (auto* astType : program.getTypes()) {
const auto& type = typeEnv.getType(*astType);

if (isA<analysis::AlgebraicDataType>(skipAliasesType(type))) {
// resolve alias-type to adt
auto& sumType = asAssert<analysis::AlgebraicDataType>(skipAliasesType(type));
auto& branches = sumType.getBranches();

std::vector<json11::Json> branchesInfo;
std::vector<json11::Json> branchesInfo;

for (const auto& branch : branches) {
std::vector<json11::Json> branchTypes;
for (auto* type : branch.types) {
branchTypes.push_back(getTypeQualifier(*type));
for (const auto& branch : branches) {
std::vector<json11::Json> branchTypes;
for (auto* type : branch.types) {
branchTypes.push_back(getTypeQualifier(*type));
}

auto branchInfo = json11::Json::object{
{{"types", std::move(branchTypes)}, {"name", branch.name.toString()}}};
branchesInfo.push_back(std::move(branchInfo));
}

auto branchInfo = json11::Json::object{
{{"types", std::move(branchTypes)}, {"name", branch.name.toString()}}};
branchesInfo.push_back(std::move(branchInfo));
auto typeQualifier = analysis::getTypeQualifier(type);
auto&& sumInfo = json11::Json::object{{{"branches", std::move(branchesInfo)},
{"arity", static_cast<long long>(branches.size())}, {"enum", isADTEnum(sumType)}}};
sumTypes.emplace(std::move(typeQualifier), std::move(sumInfo));
}

auto typeQualifier = analysis::getTypeQualifier(sumType);
auto&& sumInfo = json11::Json::object{{{"branches", std::move(branchesInfo)},
{"arity", static_cast<long long>(branches.size())}, {"enum", isADTEnum(sumType)}}};
sumTypes.emplace(std::move(typeQualifier), std::move(sumInfo));
});
}

sumTypesInfo = json11::Json(sumTypes);
return sumTypesInfo;
Expand All @@ -225,10 +230,11 @@ class IOAttributesTransformer : public Transformer {
// Iterate over all record types in the program populating the records map.
for (auto* astType : program.getTypes()) {
const auto& type = typeEnv.getType(*astType);
if (isA<analysis::RecordType>(type)) {
if (isA<analysis::RecordType>(skipAliasesType(type))) {
elementTypes.clear();

for (const analysis::Type* field : as<analysis::RecordType>(type)->getFields()) {
for (const analysis::Type* field :
as<analysis::RecordType>(skipAliasesType(type))->getFields()) {
elementTypes.push_back(getTypeQualifier(*field));
}
const std::size_t recordArity = elementTypes.size();
Expand All @@ -250,15 +256,25 @@ class IOAttributesTransformer : public Transformer {
}

Program& program = translationUnit.getProgram();
auto& typeEnv = translationUnit.getAnalysis<analysis::TypeEnvironmentAnalysis>().getTypeEnvironment();
std::vector<std::string> elementParams;
std::map<std::string, json11::Json> records;

// Iterate over all record types in the program populating the records map.
for (auto* astType : program.getTypes()) {
if (isA<ast::RecordType>(astType)) {
const auto programTypes = program.getTypes();
for (auto* astType : programTypes) {
// if the ast type is an alias, we have to traverse it:
const auto& unaliasedType = skipAliasesType(typeEnv.getType(*astType));

if (isA<analysis::RecordType>(unaliasedType)) {
elementParams.clear();

for (const auto field : as<ast::RecordType>(astType)->getFields()) {
// find the ast type associated with the unaliased type
const auto unaliasedAstType = std::find_if(programTypes.begin(), programTypes.end(),
[&](Type* x) { return x->getQualifiedName() == unaliasedType.getName(); });

// list the fields from the unaliased type
for (const auto field : as<ast::RecordType>(*unaliasedAstType)->getFields()) {
elementParams.push_back(field->getName());
}
const std::size_t recordArity = elementParams.size();
Expand Down
2 changes: 1 addition & 1 deletion src/ast/transform/TypeChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ void TypeCheckerImpl::visit_(type_identity<RecordInit>, const RecordInit& rec) {
}

// At this point we know that there is exactly one type in set, so we can take it.
auto& recordType = *as<analysis::RecordType>(*types.begin());
auto& recordType = *as<analysis::RecordType>(getBaseType(&*types.begin()));

if (recordType.getFields().size() != rec.getArguments().size()) {
report.addError("Wrong number of arguments given to record", rec.getSrcLoc());
Expand Down
3 changes: 3 additions & 0 deletions tests/semantic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,6 @@ positive_test(pragma1)
souffle_run_test_helper(TEST_NAME pragma2 FUNCTORS CATEGORY semantic)
positive_test(rel_redundant)
positive_test(type_as4)
positive_test(record_alias)
positive_test(record_alias2)
positive_test(adt_alias)
2 changes: 2 additions & 0 deletions tests/semantic/adt_alias/R.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
$end
$next($end, 1)
8 changes: 8 additions & 0 deletions tests/semantic/adt_alias/adt_alias.dl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.type L= next {a:L, b:number} | end {}
.type A=L

.decl R(x:A)
R($end()).
R($next($end(),1)).

.output R
Empty file.
Empty file.
2 changes: 2 additions & 0 deletions tests/semantic/record_alias/R.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
nil
[nil, 1]
8 changes: 8 additions & 0 deletions tests/semantic/record_alias/record_alias.dl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.type L=[a:L, b:number]
.type A=L

.decl R(x:A)
R(nil).
R([nil,1]).

.output R
Empty file.
Empty file.
Loading