Skip to content

Commit

Permalink
Deprecates absent types
Browse files Browse the repository at this point in the history
Removes all logic regarding absent types in planner
  • Loading branch information
johnedquinn committed May 15, 2024
1 parent 2a13661 commit 2154f45
Show file tree
Hide file tree
Showing 26 changed files with 632 additions and 706 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ internal fun StaticType.cast(targetType: StaticType): StaticType {

// union source types, recursively process them
when (this) {
is AnyType -> return AnyOfType(this.toAnyOfType().types.map { it.cast(targetType) }.toSet()).flatten()
is AnyType -> return StaticType.unionOf(this.toAnyOfType().types.map { it.cast(targetType) }.toSet()).flatten()
is AnyOfType -> return when (val flattened = this.flatten()) {
is SingleType, is AnyType -> flattened.cast(targetType)
is AnyOfType -> AnyOfType(flattened.types.map { it.cast(targetType) }.toSet()).flatten()
is AnyOfType -> StaticType.unionOf(flattened.types.map { it.cast(targetType) }.toSet()).flatten()
}
}

Expand Down Expand Up @@ -221,7 +221,7 @@ internal fun StaticType.cast(targetType: StaticType): StaticType {
*/
internal fun StaticType.filterNullMissing(): StaticType =
when (this) {
is AnyOfType -> AnyOfType(this.types.filter { !it.isNullOrMissing() }.toSet()).flatten()
is AnyOfType -> StaticType.unionOf(this.types.filter { !it.isNullOrMissing() }.toSet()).flatten()
else -> this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ import org.partiql.lang.util.isPosInf
import org.partiql.lang.util.ln
import org.partiql.lang.util.power
import org.partiql.lang.util.squareRoot
import org.partiql.types.AnyOfType
import org.partiql.types.StaticType
import org.partiql.types.StaticType.Companion.unionOf
import java.math.BigDecimal
Expand Down Expand Up @@ -342,7 +341,7 @@ internal object ExprFunctionUpper : ExprFunction {

override val signature = FunctionSignature(
name = "upper",
requiredParameters = listOf(AnyOfType(setOf(StaticType.STRING, StaticType.SYMBOL))),
requiredParameters = listOf(StaticType.unionOf(setOf(StaticType.STRING, StaticType.SYMBOL))),
returnType = StaticType.STRING
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,9 @@ internal class StaticTypeInferenceVisitorTransform(
// where the other arg is an incompatible type (not an unknown or bool), the result type is MISSING.
args.any { it == StaticType.BOOL } -> when {
// If other argument is missing, then return union(bool, missing)
args.any { it is MissingType } -> AnyOfType(setOf(StaticType.MISSING, StaticType.BOOL))
args.any { it is MissingType } -> StaticType.unionOf(setOf(StaticType.MISSING, StaticType.BOOL))
// If other argument is null, then return union(bool, null)
args.any { it is NullType } -> AnyOfType(setOf(StaticType.NULL, StaticType.BOOL))
args.any { it is NullType } -> StaticType.unionOf(setOf(StaticType.NULL, StaticType.BOOL))
// If other type is anything other than null or missing, then it is an error case
else -> StaticType.MISSING
}
Expand Down Expand Up @@ -1021,7 +1021,7 @@ internal class StaticTypeInferenceVisitorTransform(
StaticType.NULL.takeIf { actualType.allTypes.any { it is NullType } }
)
}
AnyOfType(finalReturnTypes.toSet()).flatten()
StaticType.unionOf(finalReturnTypes.toSet()).flatten()
} else {
// otherwise, has an invalid arg type and errors. continuation type of [FunctionSignature.returnType]
signature.returnType
Expand Down Expand Up @@ -1093,7 +1093,7 @@ internal class StaticTypeInferenceVisitorTransform(
values: List<PartiqlAst.Expr>,
compute: (StaticType) -> StaticType
): PartiqlAst.Expr {
val valuesTypes = AnyOfType(values.getStaticType().toSet()).flatten()
val valuesTypes = StaticType.unionOf(values.getStaticType().toSet()).flatten()
val inferredType = compute(valuesTypes)
return expr.withStaticType(inferredType)
}
Expand Down Expand Up @@ -1201,7 +1201,7 @@ internal class StaticTypeInferenceVisitorTransform(
}

val possibleTypes = thenExprsTypes + elseExprType
return AnyOfType(possibleTypes.toSet()).flatten()
return StaticType.unionOf(possibleTypes.toSet()).flatten()
}

// PIG ast Types => CanCast, CanLosslessCast, IsType, ExprCast
Expand Down Expand Up @@ -1319,7 +1319,7 @@ internal class StaticTypeInferenceVisitorTransform(
is BagType -> fromSourceType.elementType
is ListType -> fromSourceType.elementType
is AnyType -> StaticType.ANY
is AnyOfType -> AnyOfType(fromSourceType.types.map { getElementTypeForFromSource(it) }.toSet())
is AnyOfType -> StaticType.unionOf(fromSourceType.types.map { getElementTypeForFromSource(it) }.toSet())
// All the other types coerce into a bag of themselves (including null/missing/sexp).
else -> fromSourceType
}
Expand Down Expand Up @@ -1419,13 +1419,13 @@ internal class StaticTypeInferenceVisitorTransform(
private fun getUnpivotValueType(fromSourceType: StaticType): StaticType =
when (fromSourceType) {
is StructType -> if (fromSourceType.contentClosed) {
AnyOfType(fromSourceType.fields.map { it.value }.toSet()).flatten()
StaticType.unionOf(fromSourceType.fields.map { it.value }.toSet()).flatten()
} else {
// Content is open, so value can be of any type
StaticType.ANY
}
is AnyType -> StaticType.ANY
is AnyOfType -> AnyOfType(fromSourceType.types.map { getUnpivotValueType(it) }.toSet())
is AnyOfType -> StaticType.unionOf(fromSourceType.types.map { getUnpivotValueType(it) }.toSet())
// All the other types coerce into a struct of themselves with synthetic key names
else -> fromSourceType
}
Expand Down Expand Up @@ -1506,7 +1506,7 @@ internal class StaticTypeInferenceVisitorTransform(
StaticType.ANY
} else {
val staticTypes = prevTypes.map { inferPathComponentExprType(it, currentPathComponent) }
AnyOfType(staticTypes.toSet()).flatten()
StaticType.unionOf(staticTypes.toSet()).flatten()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ internal object PlanTransform : PlanBaseVisitor<PlanNode, ProblemCallback>() {
org.partiql.plan.Agg(
FunctionSignature.Aggregation(
"UNKNOWN_AGG::$name",
returns = PartiQLValueType.MISSING,
returns = PartiQLValueType.ANY,
parameters = emptyList()
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,11 @@ internal object RelConverter {
}

private fun convertProjectionItem(item: Select.Project.Item) = when (item) {
is Select.Project.Item.All -> convertProjectItemAll(item)
is Select.Project.Item.All -> convertProjectItemAll()
is Select.Project.Item.Expression -> convertProjectItemRex(item)
}

private fun convertProjectItemAll(item: Select.Project.Item.All): Pair<Rel.Binding, Rex> {
private fun convertProjectItemAll(): Pair<Rel.Binding, Rex> {
throw IllegalArgumentException("AST not normalized")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import org.partiql.planner.internal.ir.rexOpStructField
import org.partiql.planner.internal.ir.rexOpSubquery
import org.partiql.planner.internal.ir.rexOpTupleUnion
import org.partiql.planner.internal.ir.rexOpVarUnresolved
import org.partiql.planner.internal.typer.toNonNullStaticType
import org.partiql.planner.internal.typer.toStaticType
import org.partiql.types.StaticType
import org.partiql.types.TimeType
Expand Down Expand Up @@ -70,10 +69,7 @@ internal object RexConverter {
throw IllegalArgumentException("unsupported rex $node")

override fun visitExprLit(node: Expr.Lit, context: Env): Rex {
val type = when (node.value.isNull) {
true -> node.value.type.toStaticType()
else -> node.value.type.toNonNullStaticType()
}
val type = node.value.type.toStaticType()
val op = rexOpLit(node.value)
return rex(type, op)
}
Expand All @@ -82,10 +78,7 @@ internal object RexConverter {
val value =
PartiQLValueIonReaderBuilder
.standard().build(node.value).read()
val type = when (value.isNull) {
true -> value.type.toStaticType()
else -> value.type.toNonNullStaticType()
}
val type = value.type.toStaticType()
return rex(type, rexOpLit(value))
}

Expand Down Expand Up @@ -287,7 +280,7 @@ internal object RexConverter {
}.toMutableList()

val defaultRex = when (val default = node.default) {
null -> rex(type = StaticType.NULL, op = rexOpLit(value = nullValue()))
null -> rex(type = StaticType.ANY, op = rexOpLit(value = nullValue()))
else -> visitExprCoerce(default, context)
}
val op = rexOpCase(branches = branches, default = defaultRex)
Expand Down Expand Up @@ -528,8 +521,8 @@ internal object RexConverter {
val type = node.asType
val arg0 = visitExprCoerce(node.value, ctx)
return when (type) {
is Type.NullType -> rex(StaticType.NULL, call("cast_null", arg0))
is Type.Missing -> rex(StaticType.MISSING, call("cast_missing", arg0))
is Type.NullType -> error("Cannot cast any value to NULL")
is Type.Missing -> error("Cannot cast any value to MISSING")
is Type.Bool -> rex(StaticType.BOOL, call("cast_bool", arg0))
is Type.Tinyint -> TODO("Static Type does not have TINYINT type")
is Type.Smallint, is Type.Int2 -> rex(StaticType.INT2, call("cast_int16", arg0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

package org.partiql.planner.internal.typer

import org.partiql.types.MissingType
import org.partiql.types.NullType
import org.partiql.planner.internal.ir.Rex
import org.partiql.types.StaticType
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType
Expand All @@ -27,8 +26,6 @@ import org.partiql.value.PartiQLValueType.INT64
import org.partiql.value.PartiQLValueType.INT8
import org.partiql.value.PartiQLValueType.INTERVAL
import org.partiql.value.PartiQLValueType.LIST
import org.partiql.value.PartiQLValueType.MISSING
import org.partiql.value.PartiQLValueType.NULL
import org.partiql.value.PartiQLValueType.SEXP
import org.partiql.value.PartiQLValueType.STRING
import org.partiql.value.PartiQLValueType.STRUCT
Expand Down Expand Up @@ -57,59 +54,57 @@ internal class DynamicTyper {
private var supertype: PartiQLValueType? = null
private var args = mutableListOf<PartiQLValueType>()

private var nullable = false
private var missable = false
private val types = mutableSetOf<StaticType>()

/**
* This primarily unpacks a StaticType because of NULL, MISSING.
*
* - T
* - NULL
* - MISSING
* - (NULL)
* - (MISSING)
* - (T..)
* - (T..|NULL)
* - (T..|MISSING)
* - (T..|NULL|MISSING)
* - (NULL|MISSING)
*
* When a literal null or missing value is present in the query, its type is unknown. Therefore, its type must be
* inferred. This function ignores literal null/missing values, yet adds their indices to know how to return the
* mapping.
*/
fun accumulateUnknown() {
args.add(ANY)
}

fun accumulate(rex: Rex) {
when (rex.isLiteralAbsent()) {
true -> accumulateUnknown()
false -> accumulate(rex.type)
}
}

private fun Rex.isLiteralAbsent(): Boolean {
val op = this.op
return op is Rex.Op.Lit && (op.value.type == PartiQLValueType.MISSING || op.value.type == PartiQLValueType.NULL)
}

/**
* This cleans the input type.
* @param type
*/
fun accumulate(type: StaticType) {
val nonAbsentTypes = mutableSetOf<StaticType>()
val flatType = type.flatten()
if (flatType == StaticType.ANY) {
// Use ANY runtime; do not expand ANY
types.add(flatType)
args.add(ANY)
calculate(ANY)
return
}
for (t in flatType.allTypes) {
when (t) {
is NullType -> nullable = true
is MissingType -> missable = true
else -> nonAbsentTypes.add(t)
}
}
when (nonAbsentTypes.size) {
val allTypes = flatType.allTypes
when (allTypes.size) {
0 -> {
// Ignore in calculating supertype.
args.add(NULL)
error("This should not have happened.")
}
1 -> {
// Had single type
val single = nonAbsentTypes.first()
val single = allTypes.first()
val singleRuntime = single.toRuntimeType()
types.add(single)
args.add(singleRuntime)
calculate(singleRuntime)
}
else -> {
// Had a union; use ANY runtime
types.addAll(nonAbsentTypes)
types.addAll(allTypes)
args.add(ANY)
calculate(ANY)
}
Expand All @@ -124,31 +119,20 @@ internal class DynamicTyper {
* @return
*/
fun mapping(): Pair<StaticType, List<Pair<PartiQLValueType, PartiQLValueType>>?> {
val modifiers = mutableSetOf<StaticType>()
if (nullable) modifiers.add(StaticType.NULL)
if (missable) modifiers.add(StaticType.MISSING)
// If at top supertype, then return union of all accumulated types
if (supertype == ANY) {
return StaticType.unionOf(types + modifiers).flatten() to null
return StaticType.unionOf(types).flatten() to null
}
// If a collection, then return union of all accumulated types as these coercion rules are not defined by SQL.
if (supertype == STRUCT || supertype == BAG || supertype == LIST || supertype == SEXP) {
return StaticType.unionOf(types + modifiers) to null
return StaticType.unionOf(types) to null
}
// If not initialized, then return null, missing, or null|missing.
val s = supertype
if (s == null) {
val t = if (modifiers.isEmpty()) StaticType.MISSING else StaticType.unionOf(modifiers).flatten()
return t to null
}
val s = supertype ?: return StaticType.ANY to null
// Otherwise, return the supertype along with the coercion mapping
val type = s.toNonNullStaticType()
val type = s.toStaticType()
val mapping = args.map { it to s }
return if (modifiers.isEmpty()) {
type to mapping
} else {
StaticType.unionOf(setOf(type) + modifiers).flatten() to mapping
}
return type to mapping
}

private fun calculate(type: PartiQLValueType) {
Expand All @@ -163,7 +147,7 @@ internal class DynamicTyper {
// Lookup and set the new minimum common supertype
supertype = when {
type == ANY -> type
type == NULL || type == MISSING || s == type -> return // skip
s == type -> return // skip
else -> graph[s][type] ?: ANY // lookup, if missing then go to top.
}
}
Expand Down Expand Up @@ -206,8 +190,6 @@ internal class DynamicTyper {
graph[type] = arrayOfNulls(N)
}
graph[ANY] = edges()
graph[NULL] = edges()
graph[MISSING] = edges()
graph[BOOL] = edges(
BOOL to BOOL
)
Expand Down
Loading

0 comments on commit 2154f45

Please sign in to comment.