Skip to content

Commit

Permalink
feat: implement generic functions
Browse files Browse the repository at this point in the history
Signed-off-by: ganjing <[email protected]>
  • Loading branch information
Shanks0224 committed Sep 28, 2023
1 parent b18c7dd commit 3ac20dd
Show file tree
Hide file tree
Showing 11 changed files with 534 additions and 143 deletions.
66 changes: 1 addition & 65 deletions src/backend/binaryen/wasm_expr_gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1782,71 +1782,7 @@ export class WASMExpressionGen {
args,
funcDecl,
);
let specializedFuncName = funcName;
/* If a function is a generic function, we need to generate a specialized type function here */
if (funcDecl && funcDecl.funcType.typeArguments) {
/* record the original information */
const oriFuncType = funcDecl.funcType;
const oriFuncCtx = this.wasmCompiler.currentFuncCtx;
const oriFuncParams = funcDecl.parameters;
const oriFuncVars = funcDecl.varList;
/* change typeArgument to the specialize version */
funcDecl.funcType = funcType;
if (!funcType.specialTypeArguments) {
throw new Error('not recorded the specialized type yet');
}
let specializedSuffix = '';
for (const specializedTypeArg of funcType.specialTypeArguments!) {
specializedSuffix = specializedSuffix.concat(
'_',
specializedTypeArg.typeId.toString(),
);
}
specializedFuncName = funcName.concat(specializedSuffix);
funcDecl.name = specializedFuncName;
if (funcDecl.parameters) {
for (const p of funcDecl.parameters) {
if (
p.type instanceof TypeParameterType ||
p.type instanceof ValueTypeWithArguments
) {
this.specializeType(p.type, funcType);
}
}
}
if (funcDecl.varList) {
for (const v of funcDecl.varList) {
if (
v.type instanceof TypeParameterType ||
v.type instanceof ValueTypeWithArguments
) {
this.specializeType(v.type, funcType);
}
}
}

this.wasmCompiler.parseFunc(funcDecl);
/* restore the information */
this.wasmCompiler.currentFuncCtx = oriFuncCtx;
funcDecl.name = funcName;
funcDecl.funcType = oriFuncType;
funcDecl.parameters = oriFuncParams;
funcDecl.varList = oriFuncVars;
}
return this.module.call(specializedFuncName, callArgsRefs, returnType);
}

private specializeType(
type: TypeParameterType | ValueTypeWithArguments,
root: ValueTypeWithArguments,
) {
if (type instanceof TypeParameterType) {
const specialType = root.getSpecialTypeArg(type)!;
type.setSpecialTypeArgument(specialType);
} else {
const specTypeArgs = root.getSpecialTypeArgs(type.typeArguments!);
type.setSpecialTypeArguments(specTypeArgs);
}
return this.module.call(funcName, callArgsRefs, returnType);
}

private wasmObjFieldSet(
Expand Down
79 changes: 75 additions & 4 deletions src/expression.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@ import ts from 'typescript';
import { ParserContext } from './frontend.js';
import { ClosureEnvironment, FunctionScope } from './scope.js';
import { Variable } from './variable.js';
import { getCurScope, addSourceMapLoc } from './utils.js';
import { Type, TypeKind, builtinTypes } from './type.js';
import { getCurScope, addSourceMapLoc, isTypeGeneric } from './utils.js';
import {
TSFunction,
Type,
TypeKind,
TypeResolver,
builtinTypes,
} from './type.js';
import { Logger } from './log.js';
import { SourceMapLoc } from './backend/binaryen/utils.js';
import { ExpressionError } from './error.js';
Expand Down Expand Up @@ -514,11 +520,13 @@ export default class ExpressionProcessor {
const varReferenceScope = scope!.getNearestFunctionScope();
let variable: Variable | undefined = undefined;
let maybeClosureVar = false;
let exprType: Type = builtinTypes.get('undefined')!;

if (varReferenceScope) {
while (scope) {
variable = scope.findVariable(targetIdentifier, false);
if (variable) {
exprType = variable.varType;
break;
}

Expand All @@ -544,8 +552,9 @@ export default class ExpressionProcessor {
this.parserCtx.typeChecker!.getSymbolAtLocation(node);
if (symbol && symbol.valueDeclaration) {
declNode = symbol.valueDeclaration;
exprType = this.typeResolver.generateNodeType(declNode);
}
res.setExprType(this.typeResolver.generateNodeType(declNode));
res.setExprType(exprType);
break;
}
case ts.SyntaxKind.BinaryExpression: {
Expand Down Expand Up @@ -618,7 +627,7 @@ export default class ExpressionProcessor {
}
case ts.SyntaxKind.CallExpression: {
const callExprNode = <ts.CallExpression>node;
const expr = this.visitNode(callExprNode.expression);
let expr = this.visitNode(callExprNode.expression);
const args = new Array<Expression>(
callExprNode.arguments.length,
);
Expand All @@ -632,6 +641,68 @@ export default class ExpressionProcessor {
res.setExprType(this.typeResolver.generateNodeType(node));
break;
}

// iff a generic function is specialized and called
const origType = this.typeResolver.generateNodeType(
callExprNode.expression,
);
if (
isTypeGeneric(origType) &&
callExprNode.expression.kind === ts.SyntaxKind.Identifier
) {
// the function name of the CallExpression is corrected to the specialized function name
let typeArguments: Type[] | undefined;

// explicitly declare specialization type typeArguments
// e.g.
// function genericFunc<T> (v: T){...}
// genericFunc<number>(5);
if (callExprNode.typeArguments) {
typeArguments = callExprNode.typeArguments.map((t) => {
return this.typeResolver.generateNodeType(t);
});
}
// specialize by passing parameters
// e.g.
// function genericFunc<T> (v: T){...}
// genericFunc('hello');
if (!typeArguments) {
typeArguments = callExprNode.arguments.map((t) => {
return this.typeResolver.generateNodeType(t);
});
}

if (typeArguments) {
let genericInheritance = false;
typeArguments.forEach((t) => {
if (isTypeGeneric(t)) {
genericInheritance = true;
}
});

if (!genericInheritance) {
const typeNames = new Array<string>();
typeArguments.forEach((v) => {
typeNames.push(`${v.kind}`);
});
const newIdentifierName =
(expr as IdentifierExpression).identifierName +
'<' +
typeNames.join(',') +
'>';
expr = new IdentifierExpression(newIdentifierName);

// the function type of the CallExpression is corrected to the specialized function type
const specializedType =
this.parserCtx.currentScope!.findIdentifier(
newIdentifierName,
);
if (specializedType)
expr.setExprType(specializedType as Type);
}
}
}

const callExpr = new CallExpression(
expr,
args,
Expand Down
40 changes: 16 additions & 24 deletions src/scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ export class Scope {
private localIndex = -1;
public mangledName = '';
private modifiers: ts.Node[] = [];
// iff this Scope is specialized
private _genericOwner?: Scope;

constructor(parent: Scope | null) {
this.parent = parent;
Expand Down Expand Up @@ -179,6 +181,14 @@ export class Scope {
this.modifiers.push(modifier);
}

setGenericOwner(genericOwner: Scope) {
this._genericOwner = genericOwner;
}

get genericOwner(): Scope | undefined {
return this._genericOwner;
}

protected _nestFindScopeItem<T>(
name: string,
searchFunc: (scope: Scope) => T | undefined,
Expand Down Expand Up @@ -521,6 +531,7 @@ export class Scope {
scope.localIndex = this.localIndex;
scope.mangledName = this.mangledName;
scope.modifiers = this.modifiers;
if (this.genericOwner) scope.setGenericOwner(this.genericOwner);
}

// deep copy
Expand All @@ -537,6 +548,7 @@ export class Scope {
scope.localIndex = this.localIndex;
scope.mangledName = this.mangledName;
scope.modifiers = this.modifiers;
if (this.genericOwner) scope.setGenericOwner(this.genericOwner);
}
}

Expand Down Expand Up @@ -671,12 +683,10 @@ export class FunctionScope extends ClosureEnvironment {
/* ori func name iff func is declare */
oriFuncName: string | undefined = undefined;
debugLocations: SourceMapLoc[] = [];
// iff this FunctionScope is specialized
private _genericOwner?: FunctionScope;

constructor(parent: Scope) {
constructor(parent: Scope | null) {
super(parent);
this.debugFilePath = parent.debugFilePath;
if (parent) this.debugFilePath = parent.debugFilePath;
}

getThisIndex() {
Expand Down Expand Up @@ -726,14 +736,6 @@ export class FunctionScope extends ClosureEnvironment {
return this._className !== '';
}

setGenericOwner(genericOwner: FunctionScope) {
this._genericOwner = genericOwner;
}

get genericOwner(): FunctionScope | undefined {
return this._genericOwner;
}

copy(funcScope: FunctionScope) {
super.copy(funcScope);
funcScope.kind = this.kind;
Expand Down Expand Up @@ -782,13 +784,11 @@ export class BlockScope extends ClosureEnvironment {
export class ClassScope extends Scope {
kind = ScopeKind.ClassScope;
private _classType: TSClass = new TSClass();
// iff this ClassScope is specialized
private _genericOwner?: ClassScope;

constructor(parent: Scope, name = '') {
constructor(parent: Scope | null, name = '') {
super(parent);
this.name = name;
this.debugFilePath = parent.debugFilePath;
if (parent) this.debugFilePath = parent.debugFilePath;
}

get className(): string {
Expand All @@ -803,14 +803,6 @@ export class ClassScope extends Scope {
return this._classType;
}

setGenericOwner(genericOwner: ClassScope) {
this._genericOwner = genericOwner;
}

get genericOwner(): ClassScope | undefined {
return this._genericOwner;
}

copy(classScope: ClassScope) {
super.copy(classScope);
classScope.kind = this.kind;
Expand Down
32 changes: 28 additions & 4 deletions src/semantics/expression_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -889,12 +889,14 @@ function buildArrayLiteralExpression(
element_type = (<ArrayType>array_type).element;
}

let initValue_type: ValueType | undefined;
for (const element of expr.arrayValues) {
context.pushReference(ValueReferenceKind.RIGHT);
let v = buildExpression(element, context);
if (element_type != undefined) {
v = newCastValue(element_type, v);
}
initValue_type = v.type;
context.popReference();
init_values.push(v);
// if v is SpreadValue, add it's elem-type to init_types
Expand All @@ -921,6 +923,14 @@ function buildArrayLiteralExpression(

// return new NewLiteralArrayValue(array_type!, init_values);
const elem_type = (array_type as ArrayType).element;
// Workaround: solve the case that the return value of a generic function is a generic array
if (
elem_type.kind == ValueTypeKind.TYPE_PARAMETER &&
initValue_type &&
initValue_type.kind != ValueTypeKind.TYPE_PARAMETER
) {
array_type = createArrayType(context, initValue_type);
}
return new NewLiteralArrayValue(
array_type!,
expr.arrayValues.length == 0
Expand Down Expand Up @@ -1276,6 +1286,14 @@ export function newCastValue(
return value;
}

// Workaround: solve the case that the return value of a generic function is a generic array
if (
value_type.kind !== ValueTypeKind.GENERIC &&
type.kind === ValueTypeKind.GENERIC
) {
return value;
}

throw Error(`cannot make cast value from "${value_type}" to "${type}"`);
}

Expand Down Expand Up @@ -1758,6 +1776,16 @@ class GuessTypeArguments {
return;
}

if (templateType.kind == ValueTypeKind.UNION) {
const unionType = templateType as UnionType;
unionType.types.forEach((t) => {
if (t.kind == ValueTypeKind.TYPE_PARAMETER) {
this.updateTypeMap(t as TypeParameterType, valueType);
}
});
return;
}

if (valueType.kind != templateType.kind) {
throw Error(
`Cannot guess the value type: template: ${templateType}, valueType: ${valueType}`,
Expand All @@ -1779,9 +1807,6 @@ class GuessTypeArguments {
valueType as FunctionType,
);
break;

case ValueTypeKind.UNION:
break; // TODO
}
}

Expand Down Expand Up @@ -1981,7 +2006,6 @@ function buildCallExpression(
func_type,
specialTypeArgs,
) as FunctionType;
func_type.setSpecialTypeArguments(specialTypeArgs);
(func as FunctionCallBaseValue).funcType = func_type;
}

Expand Down
Loading

0 comments on commit 3ac20dd

Please sign in to comment.