From 4091576ab7672ceb5bdebebb5651cc8877d3389a Mon Sep 17 00:00:00 2001 From: metagn Date: Wed, 30 Oct 2024 10:58:04 +0300 Subject: [PATCH] implement generic default values for object fields (#24384) fixes #21941, fixes #23594 --- compiler/sem.nim | 1 + compiler/semdata.nim | 1 + compiler/semtypes.nim | 55 +++++++++++++++----------- compiler/semtypinst.nim | 5 +++ tests/objects/tgenericdefaultvalue.nim | 26 ++++++++++++ 5 files changed, 65 insertions(+), 23 deletions(-) create mode 100644 tests/objects/tgenericdefaultvalue.nim diff --git a/compiler/sem.nim b/compiler/sem.nim index 2ff4d758d2eb..cb0e5590cea8 100644 --- a/compiler/sem.nim +++ b/compiler/sem.nim @@ -733,6 +733,7 @@ proc preparePContext*(graph: ModuleGraph; module: PSym; idgen: IdGenerator): PCo result.semInferredLambda = semInferredLambda result.semGenerateInstance = generateInstance result.instantiateOnlyProcType = instantiateOnlyProcType + result.fitDefaultNode = fitDefaultNode result.semTypeNode = semTypeNode result.instTypeBoundOp = sigmatch.instTypeBoundOp result.hasUnresolvedArgs = hasUnresolvedArgs diff --git a/compiler/semdata.nim b/compiler/semdata.nim index 688cd97009c8..6e256b3d322e 100644 --- a/compiler/semdata.nim +++ b/compiler/semdata.nim @@ -142,6 +142,7 @@ type instantiateOnlyProcType*: proc (c: PContext, pt: LayeredIdTable, prc: PSym, info: TLineInfo): PType # used by sigmatch for explicit generic instantiations + fitDefaultNode*: proc (c: PContext, n: var PNode, expectedType: PType) includedFiles*: IntSet # used to detect recursive include files pureEnumFields*: TStrTable # pure enum fields that can be used unambiguously userPragmas*: TStrTable diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim index 757076f850f2..54e7298d4214 100644 --- a/compiler/semtypes.nim +++ b/compiler/semtypes.nim @@ -273,28 +273,23 @@ proc annotateClosureConv(n: PNode) = for i in 0.. 0: + a[^1] = semExprWithType(c, a[^1], {efDetermineType, efAllowSymChoice}, typ) + if typ == nil: + typ = a[^1].typ + else: + fitDefaultNode(c, a[^1], typ) + typ = a[^1].typ elif a[^2].kind != nkEmpty: typ = semTypeNode(c, a[^2], nil) if c.graph.config.isDefined("nimPreviewRangeDefault") and typ.skipTypes(abstractInst).kind == tyRange: @@ -885,8 +887,15 @@ proc semRecordNodeAux(c: PContext, n: PNode, check: var IntSet, pos: var int, var typ: PType var hasDefaultField = n[^1].kind != nkEmpty if hasDefaultField: - typ = fitDefaultNode(c, n) - propagateToOwner(rectype, typ) + typ = if n[^2].kind != nkEmpty: semTypeNode(c, n[^2], nil) else: nil + if c.inGenericContext > 0: + n[^1] = semExprWithType(c, n[^1], {efDetermineType, efAllowSymChoice}, typ) + if typ == nil: + typ = n[^1].typ + else: + fitDefaultNode(c, n[^1], typ) + typ = n[^1].typ + propagateToOwner(rectype, typ) elif n[^2].kind == nkEmpty: localError(c.config, n.info, errTypeExpected) typ = errorType(c) diff --git a/compiler/semtypinst.nim b/compiler/semtypinst.nim index 6a7d140cc869..2aca2c7a4baf 100644 --- a/compiler/semtypinst.nim +++ b/compiler/semtypinst.nim @@ -276,6 +276,11 @@ proc replaceTypeVarsN(cl: var TReplTypeVars, n: PNode; start=0; expectedType: PT replaceTypeVarsS(cl, n.sym, result.typ) else: replaceTypeVarsS(cl, n.sym, replaceTypeVarsT(cl, n.sym.typ)) + if result.sym.kind == skField and result.sym.ast != nil and + (cl.owner == nil or result.sym.owner == cl.owner): + # instantiate default value of object/tuple field + cl.c.fitDefaultNode(cl.c, result.sym.ast, result.sym.typ) + result.sym.typ = result.sym.ast.typ # sym type can be nil if was gensym created by macro, see #24048 if result.sym.typ != nil and result.sym.typ.kind == tyVoid: # don't add the 'void' field diff --git a/tests/objects/tgenericdefaultvalue.nim b/tests/objects/tgenericdefaultvalue.nim new file mode 100644 index 000000000000..f665fa3a4f9e --- /dev/null +++ b/tests/objects/tgenericdefaultvalue.nim @@ -0,0 +1,26 @@ +block: # issue #23594 + type + Gen[T] = object + a: T = 1.0 + + Spec32 = Gen[float32] + Spec64 = Gen[float64] + + var + a: Spec32 + b: Spec64 + doAssert sizeof(a) == 4 + doAssert sizeof(b) == 8 + doAssert a.a is float32 + doAssert b.a is float64 + +block: # issue #21941 + func what[T](): T = + 123 + + type MyObject[T] = object + f: T = what[T]() + + var m: MyObject[float] = MyObject[float]() + doAssert m.f is float + doAssert m.f == 123.0