Skip to content

Commit

Permalink
Merge pull request #1 from SymbolicML/nested-types
Browse files Browse the repository at this point in the history
Fix behavior for nested types
  • Loading branch information
MilesCranmer authored Sep 25, 2023
2 parents 8360881 + d895e37 commit 4951963
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 34 deletions.
30 changes: 20 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,30 @@ extracts the base numeric type from a numeric type `T`:
For example,

| Input Type | Output Type |
|---|---|
|:-:|---|
| `Float32` | `Float32` |
| `ComplexF32` | `Float32` |
| `Measurement{Float32}` | `Float32` |
| `Dual{BigFloat}` | `BigFloat` |
| `Dual{ComplexF32}` | `Float32` |
| `Rational{Int8}` | `Int8` |
| `Quantity{Float32,Dimensions}` | `Float32` |
| `Quantity{Float32, ...}` | `Float32` |
| `Quantity{Measurement{Float32}, ...}` | `Float32` |

Packages should write a method to `base_numeric_type`
when the base type of a numeric type
is not the first parametric type.
For example, if you were to create a quantity-like type
`Quantity{Dimensions,NumericType}`, you would need
to write a custom interface.
Package maintainers should write a specialized method for their type.
For example, to define the base numeric type for a dual number, one could write:

But if the base type comes first,
the default method will work.
```julia
import BaseType: base_numeric_type

base_numeric_type(::Type{Dual{T}}) where {T} = base_numeric_type(T)
```

It is important to call `base_numeric_type` recursively like this to deal with
nested numeric types such as `Quantity{Measurement{T}}`.

The fallback behavior of `base_numeric_type` is to return the *first* type parameter,
or, if that type has parameters of its own (such as `Dual{Complex{Float32}}`),
to recursively take the first type parameter until a non-parameterized type is found.
This works for the vast majority of types, but it is still preferred
if package maintainers write a specialized method.
27 changes: 22 additions & 5 deletions src/BaseType.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,35 @@ as a measurement or a quantity.
For example,
| Input Type | Output Type |
|---|---|
|:-:|---|
| `Float32` | `Float32` |
| `ComplexF32` | `Float32` |
| `Measurement{Float32}` | `Float32` |
| `Rational{Int8}` | `Int8` |
| `Dual{BigFloat}` | `BigFloat` |
| `Quantity{Float32,Dimensions}` | `Float32` |
| `Rational{Int8}` | `Int8` |
| `Quantity{Float32, ...}` | `Float32` |
| `Quantity{Measurement{Float32}, ...}` | `Float32` |
| `Dual{Complex{Float32}}` | `Float32` |
The standard behavior is to return the *first* type parameter,
or, if that type has parameters of its own (such as `Dual{Complex{Float32}}`),
to recursively take the first type parameter until a non-parameterized type is found.
"""
@generated function base_numeric_type(::Type{T}) where {T}
params = T isa UnionAll ? T.body.parameters : T.parameters
return isempty(params) ? :($T) : :($(first(params)))
# This uses a generated function for type stability in Julia <=1.9,
# though in Julia >=1.10 it is not necessary.
# TODO: switch to non-generated when Julia >= 1.10 is LTS.
return :($(_base_numeric_type(T)))
end
base_numeric_type(x) = base_numeric_type(typeof(x))

function _base_numeric_type(::Type{T}) where {T}
params = T isa UnionAll ? T.body.parameters : T.parameters
if isempty(params)
return T
else
return _base_numeric_type(first(params))
end
end

end
60 changes: 41 additions & 19 deletions test/unittests.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,49 @@
using Test: @test, @inferred
using Test: @test, @testset, @inferred
using BaseType: base_numeric_type
using DualNumbers: DualNumbers
using DualNumbers: DualNumbers, Dual
using DynamicQuantities: DynamicQuantities
using Measurements: ±
using Unitful: Unitful

expected_type_pairs = [
Float32 => Float32,
ComplexF64 => Float64,
DualNumbers.Dual{Int64} => Int64,
DynamicQuantities.Quantity{Float32} => Float32,
typeof(1.5DynamicQuantities.u"km/s") => Float64,
typeof(1.5f0Unitful.u"km/s") => Float32,
BigFloat => BigFloat,
typeof(1.5 ± 0.2) => Float64,
typeof(1.5f0 ± 0.2f0) => Float32,
]
@testset "Basic usage" begin
expected_type_pairs = [
Float32 => Float32,
ComplexF64 => Float64,
DualNumbers.Dual{Int64} => Int64,
DynamicQuantities.Quantity{Float32} => Float32,
typeof(1.5DynamicQuantities.u"km/s") => Float64,
typeof(1.5f0Unitful.u"km/s") => Float32,
BigFloat => BigFloat,
typeof(1.5 ± 0.2) => Float64,
typeof(1.5f0 ± 0.2f0) => Float32,
]

for (x, y) in expected_type_pairs
@eval @test base_numeric_type($x) == $y
# Make sure compiler can inline it:
@eval @inferred $y base_numeric_type($x)
for (x, y) in expected_type_pairs
@eval @test base_numeric_type($x) == $y
# Make sure compiler can inline it:
@eval @inferred $y base_numeric_type($x)
end

@test base_numeric_type(1.5DynamicQuantities.u"km/s") == base_numeric_type(typeof(1.5DynamicQuantities.u"km/s"))
@inferred base_numeric_type(1.5DynamicQuantities.u"km/s")
end

@testset "Nested types" begin
# Quantity ∘ Measurement:
x = 5Unitful.u"m/s" ± 0.1Unitful.u"m/s"
@test base_numeric_type(x) == Float64

# Quantity ∘ Dual:
y = Dual(1.0)Unitful.u"m/s"
@test base_numeric_type(y) == Float64
end

@test base_numeric_type(1.5DynamicQuantities.u"km/s") == base_numeric_type(typeof(1.5DynamicQuantities.u"km/s"))
@inferred base_numeric_type(1.5DynamicQuantities.u"km/s")
struct Node{T}
child::Union{Node{T},Nothing}
value::T
end

@testset "Safe default behavior for recursive types" begin
c = Node{Int}(Node{Int}(nothing, 1), 2)
@test base_numeric_type(c) == Int
end

0 comments on commit 4951963

Please sign in to comment.