Skip to content
This repository has been archived by the owner on Jun 22, 2021. It is now read-only.

Commit

Permalink
implement scitype for "points on a manifold" #46
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Sep 23, 2020
1 parent fd9ff43 commit 4411140
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 9 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ version = "0.3.0"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
CategoricalArrays = "^0.8"
ColorTypes = "^0.9,^0.10"
ManifoldsBase = "^0.9.5"
PrettyTables = "^0.8,^0.9"
ScientificTypes = "^1.0"
Tables = "^1.0"
Expand Down
5 changes: 3 additions & 2 deletions src/MLJScientificTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ module MLJScientificTypes

# Dependencies
using ScientificTypes
using Tables, CategoricalArrays, ColorTypes, PrettyTables, Dates
using Tables, CategoricalArrays, ColorTypes, PrettyTables, Dates,
ManifoldsBase

# re-exports from ScientificTypes
export Scientific, Found, Unknown, Known, Finite, Infinite,
OrderedFactor, Multiclass, Count, Continuous, Textual,
Binary, ColorImage, GrayImage, Image, Table,
ScientificTimeType, ScientificDate, ScientificDateTime,
ScientificTime
ScientificTime, ManifoldPoint
export scitype, scitype_union, elscitype, nonmissing, trait

# exports
Expand Down
29 changes: 22 additions & 7 deletions src/convention/scitype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ function ST.scitype(A::CArr{T,N}, ::MLJ) where {T,N}
return AbstractArray{S,N}
end

# Manifold scitype

ST.scitype(::Tuple{Any,MT}) where MT<:ManifoldsBase.Manifold = ManifoldPoint{MT}

# Table scitype

function ST.scitype(X, ::MLJ, ::Val{:table}; kw...)
Expand All @@ -39,10 +43,21 @@ end

# Scitype for fast array broadcasting

ST.Scitype(::Type{<:Integer}, ::MLJ) = Count
ST.Scitype(::Type{<:AbstractFloat}, ::MLJ) = Continuous
ST.Scitype(::Type{<:AbstractString}, ::MLJ) = Textual
ST.Scitype(::Type{<:TimeType}, ::MLJ) = ScientificTimeType
ST.Scitype(::Type{<:Date}, ::MLJ) = ScientificDate
ST.Scitype(::Type{<:Time}, ::MLJ) = ScientificTime
ST.Scitype(::Type{<:DateTime}, ::MLJ) = ScientificDateTime
const Point{MT} = Tuple{Any,MT}
const Manifold = ManifoldsBase.Manifold

ST.Scitype(::Type{<:Integer}, ::MLJ) = Count
ST.Scitype(::Type{<:AbstractFloat}, ::MLJ) = Continuous
ST.Scitype(::Type{<:AbstractString}, ::MLJ) = Textual
ST.Scitype(::Type{<:TimeType}, ::MLJ) = ScientificTimeType
ST.Scitype(::Type{<:Date}, ::MLJ) = ScientificDate
ST.Scitype(::Type{<:Time}, ::MLJ) = ScientificTime
ST.Scitype(::Type{<:DateTime}, ::MLJ) = ScientificDateTime

# Next two lines don't work https://github.com/JuliaLang/julia/issues/37703 :
# ST.Scitype(::Type{<:Point{MT}}, ::MLJ) where MT<:ManifoldsBase.Manifold =
# ManifoldPoint{MT}

# TODO: Remove the following hack when above issue is resolved:
ST.Scitype(T::Type{<:Point{<:Manifold}}, ::MLJ) = ManifoldPoint{last(T.types)}

17 changes: 17 additions & 0 deletions test/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,23 @@ end
AbstractVector{Union{Missing,ScientificTimeType}}
end

struct MySphere{N} <: ManifoldsBase.Manifold{ManifoldsBase.ℝ} where {N}
radius::Float64
end
MySphere(radius, n) = MySphere{n}(radius)

@testset "manifold point" begin
manifold1 = MySphere(1, 3)
@test scitype(("some_point_representation", manifold1)) ==
ManifoldPoint{MySphere{3}}
v1 = [(rand(), manifold1) for _ in 1:4]
@test elscitype(v1) == ManifoldPoint{MySphere{3}}
@test scitype(v1) == AbstractVector{ManifoldPoint{MySphere{3}}}
manifold2 = MySphere(1, 4)
v2 = [(rand(), manifold2) for _ in 1:3]
@test scitype(vcat(v1, v2)) <: AbstractVector{<:ManifoldPoint{<:MySphere}}
end

@testset "Type coercion" begin
X = (x=10:10:44, y=1:4, z=collect("abcd"))
types = Dict(:x => Continuous, :z => Multiclass)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test, ScientificTypes, MLJScientificTypes, Random
using Tables, CategoricalArrays, CSV, DataFrames, ColorTypes
import ManifoldsBase
using Dates

const Arr = AbstractArray
Expand Down

0 comments on commit 4411140

Please sign in to comment.