Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix behavior for nested types #1

Merged
merged 7 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,30 @@ extracts the base numeric type from a numeric type `T`:
For example,

| Input Type | Output Type |
|---|---|
|:-:|---|
| `Float32` | `Float32` |
| `ComplexF32` | `Float32` |
| `Measurement{Float32}` | `Float32` |
| `Dual{BigFloat}` | `BigFloat` |
| `Dual{ComplexF32}` | `Float32` |
| `Rational{Int8}` | `Int8` |
| `Quantity{Float32,Dimensions}` | `Float32` |
| `Quantity{Float32, ...}` | `Float32` |
| `Quantity{Measurement{Float32}, ...}` | `Float32` |

Packages should write a method to `base_numeric_type`
when the base type of a numeric type
is not the first parametric type.
For example, if you were to create a quantity-like type
`Quantity{Dimensions,NumericType}`, you would need
to write a custom interface.
Package maintainers should write a specialized method for their type.
For example, to define the base numeric type for a dual number, one could write:

But if the base type comes first,
the default method will work.
```julia
import BaseType: base_numeric_type

base_numeric_type(::Type{Dual{T}}) where {T} = base_numeric_type(T)
```

It is important to call `base_numeric_type` recursively like this to deal with
nested numeric types such as `Quantity{Measurement{T}}`.

The fallback behavior of `base_numeric_type` is to return the *first* type parameter,
or, if that type has parameters of its own (such as `Dual{Complex{Float32}}`),
to recursively take the first type parameter until a non-parameterized type is found.
This works for the vast majority of types, but it is still preferred
if package maintainers write a specialized method.
27 changes: 22 additions & 5 deletions src/BaseType.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,35 @@ as a measurement or a quantity.
For example,

| Input Type | Output Type |
|---|---|
|:-:|---|
| `Float32` | `Float32` |
| `ComplexF32` | `Float32` |
| `Measurement{Float32}` | `Float32` |
| `Rational{Int8}` | `Int8` |
| `Dual{BigFloat}` | `BigFloat` |
| `Quantity{Float32,Dimensions}` | `Float32` |
| `Rational{Int8}` | `Int8` |
| `Quantity{Float32, ...}` | `Float32` |
| `Quantity{Measurement{Float32}, ...}` | `Float32` |
| `Dual{Complex{Float32}}` | `Float32` |

The standard behavior is to return the *first* type parameter,
or, if that type has parameters of its own (such as `Dual{Complex{Float32}}`),
to recursively take the first type parameter until a non-parameterized type is found.
"""
@generated function base_numeric_type(::Type{T}) where {T}
params = T isa UnionAll ? T.body.parameters : T.parameters
return isempty(params) ? :($T) : :($(first(params)))
# This uses a generated function for type stability in Julia <=1.9,
# though in Julia >=1.10 it is not necessary.
# TODO: switch to non-generated when Julia >= 1.10 is LTS.
return :($(_base_numeric_type(T)))
end
base_numeric_type(x) = base_numeric_type(typeof(x))

function _base_numeric_type(::Type{T}) where {T}
params = T isa UnionAll ? T.body.parameters : T.parameters
if isempty(params)
return T
else
return _base_numeric_type(first(params))
end
end

end
60 changes: 41 additions & 19 deletions test/unittests.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,49 @@
using Test: @test, @inferred
using Test: @test, @testset, @inferred
using BaseType: base_numeric_type
using DualNumbers: DualNumbers
using DualNumbers: DualNumbers, Dual
using DynamicQuantities: DynamicQuantities
using Measurements: ±
using Unitful: Unitful

expected_type_pairs = [
Float32 => Float32,
ComplexF64 => Float64,
DualNumbers.Dual{Int64} => Int64,
DynamicQuantities.Quantity{Float32} => Float32,
typeof(1.5DynamicQuantities.u"km/s") => Float64,
typeof(1.5f0Unitful.u"km/s") => Float32,
BigFloat => BigFloat,
typeof(1.5 ± 0.2) => Float64,
typeof(1.5f0 ± 0.2f0) => Float32,
]
@testset "Basic usage" begin
expected_type_pairs = [
Float32 => Float32,
ComplexF64 => Float64,
DualNumbers.Dual{Int64} => Int64,
DynamicQuantities.Quantity{Float32} => Float32,
typeof(1.5DynamicQuantities.u"km/s") => Float64,
typeof(1.5f0Unitful.u"km/s") => Float32,
BigFloat => BigFloat,
typeof(1.5 ± 0.2) => Float64,
typeof(1.5f0 ± 0.2f0) => Float32,
]

for (x, y) in expected_type_pairs
@eval @test base_numeric_type($x) == $y
# Make sure compiler can inline it:
@eval @inferred $y base_numeric_type($x)
for (x, y) in expected_type_pairs
@eval @test base_numeric_type($x) == $y
# Make sure compiler can inline it:
@eval @inferred $y base_numeric_type($x)
end

@test base_numeric_type(1.5DynamicQuantities.u"km/s") == base_numeric_type(typeof(1.5DynamicQuantities.u"km/s"))
@inferred base_numeric_type(1.5DynamicQuantities.u"km/s")
end

@testset "Nested types" begin
# Quantity ∘ Measurement:
x = 5Unitful.u"m/s" ± 0.1Unitful.u"m/s"
@test base_numeric_type(x) == Float64

# Quantity ∘ Dual:
y = Dual(1.0)Unitful.u"m/s"
@test base_numeric_type(y) == Float64
end

@test base_numeric_type(1.5DynamicQuantities.u"km/s") == base_numeric_type(typeof(1.5DynamicQuantities.u"km/s"))
@inferred base_numeric_type(1.5DynamicQuantities.u"km/s")
struct Node{T}
child::Union{Node{T},Nothing}
value::T
end

@testset "Safe default behavior for recursive types" begin
c = Node{Int}(Node{Int}(nothing, 1), 2)
@test base_numeric_type(c) == Int
end
Loading