Skip to content

Commit

Permalink
Add a warning for generic superclasses that don't include (?) (#22784)
Browse files Browse the repository at this point in the history
This PR takes two steps:

1. It makes it legal to write `class C : GenericParent(?)` in order to
clearly indicate that `class C` is generic due to inheriting from a
generic class.
2. It adds a warning to request this form when inheriting from a generic
class, and a related error for the case of `class C :
ConcreteParent(?)`.

The rationale for the warning is that, if code compiling without these
warnings (including the field warning from PRs #22745 and #22697) then
it is syntactically apparent whether or not a class or record is
generic. That means that, in the future, if we make these errors, a) the
compiler doesn't have to work as hard to know if a type is generic and
b) neither do users.

Reviewed by @vasslitvinov - thanks!

- [x] full comm=none testing
  • Loading branch information
mppf authored Jul 27, 2023
2 parents 343b786 + 2a99b8a commit ec7d920
Show file tree
Hide file tree
Showing 93 changed files with 241 additions and 132 deletions.
9 changes: 8 additions & 1 deletion compiler/passes/convert-uast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3806,9 +3806,13 @@ struct Converter {
const char* name = astr(node->name());
const char* cname = name;
Expr* inherit = nullptr;
bool inheritMarkedGeneric = false;

if (auto cls = node->toClass()) {
inherit = convertExprOrNull(cls->parentClass());
const uast::Identifier* ident =
uast::Class::getInheritExprIdent(cls->parentClass(),
inheritMarkedGeneric);
inherit = convertExprOrNull(ident);
}

if (node->linkageName()) {
Expand All @@ -3833,6 +3837,9 @@ struct Converter {

attachSymbolAttributes(node, ret->sym);
attachSymbolVisibility(node, ret->sym);
if (inheritMarkedGeneric) {
ret->sym->addFlag(FLAG_SUPERCLASS_MARKED_GENERIC);
}

// Note the type is converted so we can wire up SymExprs later
noteConvertedSym(node, ret->sym);
Expand Down
27 changes: 24 additions & 3 deletions compiler/passes/scopeResolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,33 @@ static void markGenerics() {
} while (changed);
}

static void checkClass(AggregateType* ct) {
if (isClass(ct) && ct->symbol->hasFlag(FLAG_EXTERN)) {
USR_FATAL_CONT(ct, "Extern classes are not supported.");
}
// Warn for superclass should be marked generic
// Error for a concrete superclass that is marked generic
if (isClass(ct) && ct->dispatchParents.n == 1) {
if (AggregateType* parent = ct->dispatchParents.v[0]) {
if (isClass(parent)) {
if (!ct->symbol->hasFlag(FLAG_SUPERCLASS_MARKED_GENERIC) &&
parent->isGeneric() && !parent->isGenericWithDefaults()) {
USR_WARN(ct->symbol, "missing '(?)' after a generic parent class");
}
if (ct->symbol->hasFlag(FLAG_SUPERCLASS_MARKED_GENERIC) &&
!parent->isGeneric()) {
USR_FATAL(ct->symbol, "'(?)' after a concrete parent class");
}
}
}
}
}

static void processGenericFields() {
forv_Vec(AggregateType, ct, gAggregateTypes) {
// Do some checks now that generic-ness is settled
checkClass(ct);
// Build the type constructor now that we know which types are generic
if (isClass(ct) && ct->symbol->hasFlag(FLAG_EXTERN)) {
USR_FATAL_CONT(ct, "Extern classes are not supported.");
}
ct->processGenericFields();
}
}
Expand Down
8 changes: 7 additions & 1 deletion frontend/include/chpl/uast/Class.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Class final : public AggregateDecl {
numElements),
parentClassChildNum_(parentClassChildNum) {
CHPL_ASSERT(parentClassChildNum_ == NO_CHILD ||
child(parentClassChildNum_)->isIdentifier());
isAcceptableInheritExpr(child(parentClassChildNum_)));
}

Class(Deserializer& des)
Expand Down Expand Up @@ -111,6 +111,12 @@ class Class final : public AggregateDecl {

DECLARE_STATIC_DESERIALIZE(Class);

/** Returns the inherited Identifier, including considering
one marked generic with Superclass(?) */
static const Identifier* getInheritExprIdent(const AstNode* ast,
bool& markedGeneric);
/** Returns true if the passed inherit expression is legal */
static bool isAcceptableInheritExpr(const AstNode* ast);
};


Expand Down
3 changes: 2 additions & 1 deletion frontend/include/chpl/uast/PragmaList.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ PRAGMA(LOCAL_ON, npr, "local on", ncm)
PRAGMA(LOOP_BODY_ARGUMENT_CLASS, npr, "loop body argument class", ncm)
PRAGMA(MANAGED_POINTER, ypr, "managed pointer", "e.g. Owned and Shared")
PRAGMA(MANAGED_POINTER_NONNILABLE, npr, "managed pointer nonnilable", "e.g. non-nilable Owned and Shared")
PRAGMA(MARKED_GENERIC, npr, "marked generic", "formal is marked generic using the type query syntax")
PRAGMA(MARKED_GENERIC, npr, "marked generic", "marked generic using the type query syntax")
PRAGMA(SUPERCLASS_MARKED_GENERIC, npr, "supreclass marked generic", "superclass is marked generic")
PRAGMA(MAYBE_ARRAY_TYPE, npr, "maybe array type", "function may be computing array type")
PRAGMA(MAYBE_COPY_ELIDED, npr, "maybe copy elided", "symbol might be dead early due to copy elision")
PRAGMA(MAYBE_PARAM, npr, "maybe param", "symbol can resolve to a param")
Expand Down
29 changes: 25 additions & 4 deletions frontend/lib/parsing/ParserContextImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2498,7 +2498,7 @@ ParserContext::buildAggregateTypeDecl(YYLTYPE location,

auto contentsList = consumeList(contents);

owned<Identifier> inheritIdentifier;
owned<AstNode> inheritExpr;
if (optInherit != nullptr) {
if (optInherit->size() > 0) {
if (parts.tag == asttags::Record) {
Expand All @@ -2510,11 +2510,32 @@ ParserContext::buildAggregateTypeDecl(YYLTYPE location,
if (optInherit->size() > 1)
error(inheritLoc, "only single inheritance is supported.");
AstNode* ast = (*optInherit)[0];
bool inheritOk = false;
if (ast->isIdentifier()) {
inheritIdentifier = toOwned(ast->toIdentifier());
// inheriting from e.g. Parent is OK
inheritOk = true;
} else {
// inheriting from e.g. Parent(?) is OK
if (auto call = ast->toFnCall()) {
const AstNode* calledExpr = call->calledExpression();
if (calledExpr != nullptr && call->numActuals() == 1) {
if (const AstNode* actual = call->actual(0)) {
if (auto id = actual->toIdentifier()) {
if (id->name() == USTR("?")) {
inheritOk = true;
}
}
}
}
}
}

if (inheritOk) {
inheritExpr = toOwned(ast);
(*optInherit)[0] = nullptr;
} else {
syntax(inheritLoc, "non-Identifier expression cannot be inherited.");
syntax(inheritLoc,
"invalid parent class; please specify a single class name");
}
}
}
Expand All @@ -2534,7 +2555,7 @@ ParserContext::buildAggregateTypeDecl(YYLTYPE location,
toOwned(parts.attributeGroup),
parts.visibility,
parts.name,
std::move(inheritIdentifier),
std::move(inheritExpr),
std::move(contentsList)).release();
} else if (parts.tag == asttags::Record) {
decl = Record::build(builder, convertLocation(location),
Expand Down
8 changes: 5 additions & 3 deletions frontend/lib/resolution/Resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,13 @@ gatherParentClassScopesForScopeResolving(Context* context, ID classDeclId) {
// Intended to avoid calling methodReceiverScopes() recursively.
// Uses the empty 'savecReceiverScopes' because the class expression
// can't be a method anyways.
visitor.resolveIdentifier(parentClassExpr->toIdentifier(),
visitor.savedReceiverScopes);
bool ignoredMarkedGeneric = false;
auto ident = Class::getInheritExprIdent(parentClassExpr,
ignoredMarkedGeneric);
visitor.resolveIdentifier(ident, visitor.savedReceiverScopes);


ResolvedExpression& re = r.byAst(parentClassExpr);
ResolvedExpression& re = r.byAst(ident);
if (re.toId().isEmpty()) {
context->error(parentClassExpr, "invalid parent class expression");
} else {
Expand Down
35 changes: 35 additions & 0 deletions frontend/lib/uast/Class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "chpl/uast/Class.h"

#include "chpl/uast/Builder.h"
#include "chpl/uast/FnCall.h"
#include "chpl/uast/Identifier.h"

namespace chpl {
namespace uast {
Expand Down Expand Up @@ -70,6 +72,39 @@ owned<Class> Class::build(Builder* builder, Location loc,
return toOwned(ret);
}

const Identifier* Class::getInheritExprIdent(const AstNode* ast,
bool& markedGeneric) {
if (ast != nullptr) {
if (ast->isIdentifier()) {
// inheriting from e.g. Parent is OK
markedGeneric = false;
return ast->toIdentifier();
} else if (auto call = ast->toFnCall()) {
const AstNode* calledExpr = call->calledExpression();
if (calledExpr != nullptr && calledExpr->isIdentifier() &&
call->numActuals() == 1) {
if (const AstNode* actual = call->actual(0)) {
if (auto id = actual->toIdentifier()) {
if (id->name() == USTR("?")) {
// inheriting from e.g. Parent(?) is OK
markedGeneric = true;
return calledExpr->toIdentifier();
}
}
}
}
}
}

markedGeneric = false;
return nullptr;
}

bool Class::isAcceptableInheritExpr(const AstNode* ast) {
bool ignoredMarkedGeneric = false;
return getInheritExprIdent(ast, ignoredMarkedGeneric) != nullptr;
}


} // namespace uast
} // namespace chpl
4 changes: 2 additions & 2 deletions modules/dists/BlockCycDist.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ proc LocBlockCyclic.writeThis(x) throws {
////////////////////////////////////////////////////////////////////////////////
// BlockCyclic Domain Class
//
class BlockCyclicDom: BaseRectangularDom {
class BlockCyclicDom: BaseRectangularDom(?) {
//
// LEFT LINK: a pointer to the parent distribution
//
Expand Down Expand Up @@ -845,7 +845,7 @@ proc LocBlockCyclicDom._sizes {
////////////////////////////////////////////////////////////////////////////////
// BlockCyclic Array Class
//
class BlockCyclicArr: BaseRectangularArr {
class BlockCyclicArr: BaseRectangularArr(?) {

//
// LEFT LINK: the global domain descriptor for this array
Expand Down
4 changes: 2 additions & 2 deletions modules/dists/BlockDist.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ class LocBlock {
// locDoms: a non-distributed array of local domain classes
// whole: a non-distributed domain that defines the domain's indices
//
class BlockDom: BaseRectangularDom {
class BlockDom: BaseRectangularDom(?) {
type sparseLayoutType;
const dist: unmanaged Block(rank, idxType, sparseLayoutType);
var locDoms: [dist.targetLocDom] unmanaged LocBlockDom(rank, idxType, strides);
Expand Down Expand Up @@ -388,7 +388,7 @@ class LocBlockDom {
// locArr: a non-distributed array of local array classes
// myLocArr: optimized reference to here's local array class (or nil)
//
class BlockArr: BaseRectangularArr {
class BlockArr: BaseRectangularArr(?) {
type sparseLayoutType;
var doRADOpt: bool = defaultDoRADOpt;
var dom: unmanaged BlockDom(rank, idxType, strides, sparseLayoutType);
Expand Down
4 changes: 2 additions & 2 deletions modules/dists/CyclicDist.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ class LocCyclic {
}


class CyclicDom : BaseRectangularDom {
class CyclicDom : BaseRectangularDom(?) {
const dist: unmanaged Cyclic(rank, idxType);

var locDoms: [dist.targetLocDom] unmanaged LocCyclicDom(rank, idxType);
Expand Down Expand Up @@ -736,7 +736,7 @@ private proc myBlockType(param rank, type idxType) type do
proc LocCyclicDom.contains(i) do return myBlock.contains(i);


class CyclicArr: BaseRectangularArr {
class CyclicArr: BaseRectangularArr(?) {
var doRADOpt: bool = defaultDoRADOpt;
var dom: unmanaged CyclicDom(rank, idxType, strides);

Expand Down
4 changes: 2 additions & 2 deletions modules/dists/DimensionalDist2D.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ private proc locDescTypeHelper(param rank : int, type idxType, dom1, dom2) type
return unmanaged LocDimensionalDom(domain(rank, idxType, str), d1type, d2type);
}

class DimensionalDom : BaseRectangularDom {
class DimensionalDom : BaseRectangularDom(?) {
// required
const dist; // not reprivatized

Expand Down Expand Up @@ -351,7 +351,7 @@ class LocDimensionalDom {
var doml1, doml2;
}

class DimensionalArr : BaseRectangularArr {
class DimensionalArr : BaseRectangularArr(?) {
// required
const dom; // must be a DimensionalDom

Expand Down
2 changes: 1 addition & 1 deletion modules/dists/HashedDist.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ class LocUserMapAssocDom {
//
// the global array class
//
class UserMapAssocArr: AbsBaseArr {
class UserMapAssocArr: AbsBaseArr(?) {
// GENERICS:

//
Expand Down
5 changes: 3 additions & 2 deletions modules/dists/PrivateDist.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class Private: BaseDist {
override proc singleton() param do return true;
}

class PrivateDom: BaseRectangularDom {
class PrivateDom: BaseRectangularDom(?) {
var dist: unmanaged Private;

iter these() { for i in 0..numLocales-1 do yield i; }
Expand Down Expand Up @@ -170,7 +170,8 @@ private proc checkCanMakeDefaultValue(type eltType) param {
var default: eltType;
}

class PrivateArr: BaseRectangularArr {
class PrivateArr: BaseRectangularArr(?) {

var dom: unmanaged PrivateDom(rank, idxType, strides);

// may be initialized separately
Expand Down
4 changes: 2 additions & 2 deletions modules/dists/ReplicatedDist.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ proc Replicated.dsiPrivatize(privatizeData)
//
// global domain class
//
class ReplicatedDom : BaseRectangularDom {
class ReplicatedDom : BaseRectangularDom(?) {
// we need to be able to provide the domain map for our domain - to build its
// runtime type (because the domain map is part of the type - for any domain)
// (looks like it must be called exactly 'dist')
Expand Down Expand Up @@ -386,7 +386,7 @@ proc ReplicatedDom.dsiAssignDomain(rhs: domain, lhsPrivate:bool) {
//
// global array class
//
class ReplicatedArr : AbsBaseArr {
class ReplicatedArr : AbsBaseArr(?) {
// These two are hard-coded in the compiler - it computes the array's
// type string as '[dom.type] eltType.type'
const dom; // must be a ReplicatedDom
Expand Down
4 changes: 2 additions & 2 deletions modules/dists/SparseBlockDist.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ record TargetLocaleComparator {
// locDoms: a non-distributed array of local domain classes
// whole: a non-distributed domain that defines the domain's indices
//
class SparseBlockDom: BaseSparseDomImpl {
class SparseBlockDom: BaseSparseDomImpl(?) {
type sparseLayoutType;
param strides = strideKind.one; // TODO: remove default value eventually
const dist: unmanaged Block(rank, idxType, sparseLayoutType);
Expand Down Expand Up @@ -387,7 +387,7 @@ class LocSparseBlockDom {
// locArr: a non-distributed array of local array classes
// myLocArr: optimized reference to here's local array class (or nil)
//
class SparseBlockArr: BaseSparseArr {
class SparseBlockArr: BaseSparseArr(?) {
param strides: strideKind;
type sparseLayoutType = unmanaged DefaultDist;

Expand Down
4 changes: 2 additions & 2 deletions modules/dists/StencilDist.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ class LocStencil {
// locDoms: a non-distributed array of local domain classes
// whole: a non-distributed domain that defines the domain's indices
//
class StencilDom: BaseRectangularDom {
class StencilDom: BaseRectangularDom(?) {
param ignoreFluff : bool;
const dist: unmanaged Stencil(rank, idxType, ignoreFluff);
var locDoms: [dist.targetLocDom] unmanaged LocStencilDom(rank, idxType,
Expand Down Expand Up @@ -329,7 +329,7 @@ class LocStencilDom {
// locArr: a non-distributed array of local array classes
// myLocArr: optimized reference to here's local array class (or nil)
//
class StencilArr: BaseRectangularArr {
class StencilArr: BaseRectangularArr(?) {
param ignoreFluff: bool;
var doRADOpt: bool = defaultDoRADOpt;
var dom: unmanaged StencilDom(rank, idxType, strides, ignoreFluff);
Expand Down
4 changes: 2 additions & 2 deletions modules/internal/ArrayViewRankChange.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ module ArrayViewRankChange {
// for rectangular domains (because they're the only ones with
// rank>1), so this is a subclass of BaseRectangularDom.
//
class ArrayViewRankChangeDom: BaseRectangularDom {
class ArrayViewRankChangeDom: BaseRectangularDom(?) {
// the lower-dimensional index set that we represent upwards
var upDomInst: unmanaged DefaultRectangularDom(rank, idxType, strides)?;
forwarding upDom except these, chpl__serialize, chpl__deserialize;
Expand Down Expand Up @@ -468,7 +468,7 @@ module ArrayViewRankChange {
// interface.
//
pragma "aliasing array"
class ArrayViewRankChangeArr: AbsBaseArr {
class ArrayViewRankChangeArr: AbsBaseArr(?) {
// the representation of the slicing domain. For a rank change
// like A[lo..hi, 3] this is the lower-dimensional domain {lo..hi}.
// It is represented as an ArrayViewRankChangeDom.
Expand Down
Loading

0 comments on commit ec7d920

Please sign in to comment.