Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first draft of OLS #14

Merged
merged 12 commits into from
Sep 26, 2024
1 change: 1 addition & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"ghcr.io/devcontainers-contrib/features/clojure-asdf:2": {},
"ghcr.io/devcontainers-contrib/features/tmux-apt-get:1": {},
"ghcr.io/devcontainers-contrib/features/redis-homebrew:1" : {},
"ghcr.io/rocker-org/devcontainer-features/r-apt:latest": {},
"ghcr.io/devcontainers-contrib/features/bash-command:1": {"command": "apt-get update && apt-get install -y rlwrap"}


Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
unreleased

- added OLS
- added tidy output validation

0.8.2
- fixed metric bug

Expand Down
2 changes: 1 addition & 1 deletion resources/columms-glance.edn
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
;; :missing-method "Method for eliminating missing data.",
;; :model "A character string denoting the model at which the optimal BIC occurs.",
:mse "Mean Squared Deviation"
:n "The total number of observations.",
:n "The total number of observations.",
;; :n.clusters "Number of clusters.",
;; :n.factors "The number of fitted factors.",
;; :n.max "Maximum number of subjects at risk.",
Expand Down
15 changes: 12 additions & 3 deletions src/scicloj/metamorph/ml.clj
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
[tech.v3.datatype.errors :as errors]
[tech.v3.datatype.export-symbols :as exporter]
[tech.v3.datatype.functional :as dfn]
[scicloj.metamorph.ml.tidy-models :as tidy]
[clojure.set :as set])
;;

Expand Down Expand Up @@ -715,10 +716,14 @@
(get
(options->model-def (:options model))
:tidy-fn)]

(if tidy-fn
(tidy-fn model)
(tidy/validate-tidy-ds
(tidy-fn model))
(ds/->dataset {}))))



(defn glance
"Gives a glance on the model, returning a dataset with model information
about the entire model.
Expand All @@ -738,10 +743,12 @@
(options->model-def (:options model))
:glance-fn)]
(if glance-fn
(glance-fn model)
(tidy/validate-glance-ds
(glance-fn model))
(ds/->dataset {}))))



(defn augment
"
Adds informations about observations to a dataset
Expand All @@ -763,7 +770,9 @@
(options->model-def (:options model))
:augment-fn)]
(if augment-fn
(augment-fn model data)
(tidy/validate-augment-ds
(augment-fn model data)
data)
data)))


Expand Down
79 changes: 79 additions & 0 deletions src/scicloj/metamorph/ml/regression.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
(ns scicloj.metamorph.ml.regression
(:require [scicloj.metamorph.ml :as ml]
[tech.v3.dataset :as ds]
[tablecloth.api :as tc]
[tech.v3.datatype :as dt]
[tech.v3.dataset.modelling :as ds-mod]
[scicloj.metamorph.ml.toydata :as data]
[tech.v3.dataset.tensor :as dtt])
(:import [org.apache.commons.math3.stat.regression OLSMultipleLinearRegression]))



(defn- tidy-ols
  "Builds the per-coefficient summary of a fitted OLS model.

  `model` is the trained model map; its :model-data holds the fitted
  OLSMultipleLinearRegression and :feature-columns the regressor names.

  Returns a dataset with one row per regression parameter and the columns
  :term, :estimate and :std.error."
  [model]
  (let [^OLSMultipleLinearRegression ols (:model-data model)]
    (ds/->dataset
     ;; The first estimated parameter is the intercept, not a coefficient of
     ;; the target column, so it is labelled :intercept rather than with the
     ;; target column name.
     {:term      (cons :intercept (:feature-columns model))
      :estimate  (.estimateRegressionParameters ols)
      :std.error (.estimateRegressionParametersStandardErrors ols)})))


(defn- augment-fn
  "Returns `data` with an additional :.resid column holding the residuals
  estimated by the fitted regression stored under :model-data of `model`.

  NOTE(review): the residuals come from the sample the model was fitted on,
  so this presumably expects `data` to be the training data - confirm."
  [model data]
  (let [fitted    (:model-data model)
        residuals (.estimateResiduals fitted)]
    (tc/add-column data :.resid residuals)))


(defn- glance-ols
  "Builds the one-row, model-level summary of a fitted OLS model.

  `model` is the trained model map; its :model-data holds the fitted
  OLSMultipleLinearRegression.

  Returns a dataset with the columns:
    :totss          total sum of squares
    :adj.r.squared  adjusted R squared
    :rss            residual sum of squares
    :sigma          residual standard error"
  [model]
  (let [^OLSMultipleLinearRegression ols (:model-data model)]
    (ds/->dataset
     {:totss         (.calculateTotalSumOfSquares ols)
      :adj.r.squared (.calculateAdjustedRSquared ols)
      :rss           (.calculateResidualSumOfSquares ols)
      ;; TODO: clarify whether (.estimateRegressandVariance ols) should be
      ;; reported as well.
      ;; :sigma is the residual standard error, i.e. the square root of the
      ;; error variance; reporting the variance itself would be sigma^2.
      :sigma         (.estimateRegressionStandardError ols)})))

(defn- train-ols
  "Fits an ordinary-least-squares regression.

  `feature-ds` - dataset of regressors, one column per variable.
  `target-ds`  - dataset holding exactly one column, the regressand.
  `options`    - model options (currently unused).

  Returns the fitted OLSMultipleLinearRegression instance.
  Throws ex-info when `target-ds` does not have exactly one column."
  [feature-ds target-ds options]
  (when-not (= 1 (ds/column-count target-ds))
    (throw (ex-info "OLS requires exactly one target column."
                    {:target-columns (ds/column-names target-ds)})))
  (let [;; Sample dimensions are taken from the data actually passed in,
        ;; not from a fixed example dataset.
        n-observations (ds/row-count feature-ds)
        n-variables    (ds/column-count feature-ds)
        ;; newSampleData expects one flat array laid out row by row with the
        ;; regressand first: y0, x00, x01, ..., y1, x10, x11, ...
        ;; Hence the target column is placed in front of the features.
        values         (-> (tc/append target-ds feature-ds)
                           (dtt/dataset->tensor)
                           (dt/->double-array))
        ols            (OLSMultipleLinearRegression.)]
    (.newSampleData ols values n-observations n-variables)
    ols))

(defn- predict-ols
  "Prediction entry point registered for the :metamorph.ml/ols model.

  Prediction is not implemented for this model, so this always throws an
  ex-info carrying the explanatory message. A bare string cannot be thrown
  (it is not a Throwable), so the message is wrapped in an exception."
  [feature-ds thawed-model model]
  (throw (ex-info "Prediction is not supported by this model."
                  {:model-type :metamorph.ml/ols})))


;; Register the OLS model under :metamorph.ml/ols together with its
;; tidy / glance / augment functions.
(ml/define-model!
  :metamorph.ml/ols
  train-ols
  predict-ols
  {:augment-fn augment-fn
   :glance-fn  glance-ols
   :tidy-fn    tidy-ols})
58 changes: 58 additions & 0 deletions src/scicloj/metamorph/ml/tidy_models.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
(ns scicloj.metamorph.ml.tidy-models
  (:require
   [clojure.edn :as edn]
   [clojure.java.io :as io]
   [clojure.set :as set]
   [tech.v3.dataset :as ds]))

(def ^{:dynamic true
       :added   "1.0"
       :doc     "Controls whether the result columns of a model's tidy functions
  (glance-fn, tidy-fn, augment-fn) are validated against the allowed
  column names defined in
  https://github.com/scicloj/metamorph.ml/tree/main/resources/*.edn
  and whether a violation makes them fail."}
  *validate-tidy-fns*
  true)

(defn- read-allowed-columns
  "Reads the EDN map stored in the classpath resource `resource-name` and
  returns its keys. Throws ex-info when the resource cannot be found."
  [resource-name]
  (if-let [resource (io/resource resource-name)]
    (keys (edn/read-string (slurp resource)))
    (throw (ex-info (str "Column definition resource not found on classpath: "
                         resource-name)
                    {:resource resource-name}))))

;; The column definitions are read from the project's own resources instead of
;; being downloaded from a branch URL at runtime, so validation works offline
;; and does not depend on a particular git branch continuing to exist.
;; NOTE(review): assumes the `resources` directory is on the classpath - confirm.

(defn allowed-glance-columns
  "Returns the column names a glance-fn result may contain."
  []
  (read-allowed-columns "columms-glance.edn"))

(defn allowed-tidy-columns
  "Returns the column names a tidy-fn result may contain."
  []
  (read-allowed-columns "columms-tidy.edn"))

(defn allowed-augment-columns
  "Returns the column names an augment-fn may add to its input data."
  []
  (read-allowed-columns "columms-augment.edn"))

(defn _get-allowed-keys
  "Loads the allowed column names for all three tidy functions and returns
  them in a map keyed by :glance, :tidy and :augment."
  []
  (let [glance-columns  (allowed-glance-columns)
        tidy-columns    (allowed-tidy-columns)
        augment-columns (allowed-augment-columns)]
    {:glance  glance-columns
     :tidy    tidy-columns
     :augment augment-columns}))

(def get-allowed-keys
  "Memoized variant of `_get-allowed-keys`: the allowed column names are
  loaded on the first call and reused afterwards."
  (memoize _get-allowed-keys))



(defn- validate-ds
  "Checks that every column name of `ds` is contained in `allowed-columns`.

  Returns `ds` unchanged when all columns are allowed, or when validation is
  switched off (`*validate-tidy-fns*` is not true). Otherwise throws an
  Exception naming `fn-name` and the offending column names."
  [ds allowed-columns fn-name]
  (if-not (true? *validate-tidy-fns*)
    ds
    (let [present      (set (ds/column-names ds))
          permitted    (set allowed-columns)
          invalid-keys (set/difference present permitted)]
      (when (seq invalid-keys)
        (throw (Exception. (format "invalid keys from %s: %s" fn-name invalid-keys))))
      ds)))

(defn validate-tidy-ds
  "Validates the dataset returned by a tidy-fn against the allowed tidy
  columns; returns `ds` or throws, see `validate-ds`."
  [ds]
  (let [allowed (get (get-allowed-keys) :tidy)]
    (validate-ds ds allowed "tidy-fn")))

(defn validate-glance-ds
  "Validates the dataset returned by a glance-fn against the allowed glance
  columns; returns `ds` or throws, see `validate-ds`."
  [ds]
  (let [allowed (get (get-allowed-keys) :glance)]
    (validate-ds ds allowed "glance-fn")))

(defn validate-augment-ds
  "Validates the dataset returned by an augment-fn. Allowed are the augment
  columns plus every column already present in the input `data`.
  Returns `ds` or throws, see `validate-ds`."
  [ds data]
  (let [augment-columns (:augment (get-allowed-keys))
        input-columns   (ds/column-names data)]
    (validate-ds ds
                 (concat augment-columns input-columns)
                 "augment-fn")))
122 changes: 122 additions & 0 deletions test/scicloj/metamorph/linear_regression_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
(ns scicloj.metamorph.linear-regression-test
(:require
[scicloj.metamorph.ml :as ml]
[scicloj.metamorph.ml.toydata :as data]
[tech.v3.dataset :as ds]
[scicloj.metamorph.ml.regression]
[tech.v3.dataset.modelling :as ds-mod]))

;; mtcars without its :model column, with :mpg marked as the inference target.
(def ds
  (ds-mod/set-inference-target
   (ds/drop-columns (data/mtcars-ds) [:model])
   :mpg))

;; Train the OLS model on the prepared dataset (:mpg is the inference target).
(def model (ml/train ds {:model-type :metamorph.ml/ols}))

;; Model-level summary as a one-row dataset; captured output follows.
(ml/glance model)
;; => _unnamed [1 3]:
;; | :totss | :adj.r.squared | :rss |
;; |-------------:|---------------:|-------------:|
;; | 1126.0471875 | 0.80664232 | 147.49443002 |
;; Per-term estimates and standard errors. In the captured output below the
;; first row (:mpg) holds the intercept - compare `(Intercept)` in the R
;; output further down.
(ml/tidy model)
;; => _unnamed [11 3]:
;; | :term | :estimate | :std.error |
;; |-------|------------:|------------:|
;; | :mpg | 12.30337416 | 18.71788443 |
;; | :cyl | -0.11144048 | 1.04502336 |
;; | :disp | 0.01333524 | 0.01785750 |
;; | :hp | -0.02148212 | 0.02176858 |
;; | :drat | 0.78711097 | 1.63537307 |
;; | :wt | -3.71530393 | 1.89441430 |
;; | :qsec | 0.82104075 | 0.73084480 |
;; | :vs | 0.31776281 | 2.10450861 |
;; | :am | 2.52022689 | 2.05665055 |
;; | :gear | 0.65541302 | 1.49325996 |
;; | :carb | -0.19941925 | 0.82875250 |


;; Input rows with the residuals appended as the :.resid column.
(ml/augment model ds)
;; => _unnamed [32 12]:
;; | :mpg | :cyl | :disp | :hp | :drat |   :wt | :qsec | :vs | :am | :gear | :carb |     :.resid |
;; |-----:|-----:|------:|----:|------:|------:|------:|----:|----:|------:|------:|------------:|
;; | 21.0 | 6 | 160.0 | 110 | 3.90 | 2.620 | 16.46 | 0 | 1 | 4 | 4 | -1.59950576 |
;; | 21.0 | 6 | 160.0 | 110 | 3.90 | 2.875 | 17.02 | 0 | 1 | 4 | 4 | -1.11188608 |
;; | 22.8 | 4 | 108.0 | 93 | 3.85 | 2.320 | 18.61 | 1 | 1 | 4 | 1 | -3.45064408 |
;; | 21.4 | 6 | 258.0 | 110 | 3.08 | 3.215 | 19.44 | 1 | 0 | 3 | 1 | 0.16259545 |
;; | 18.7 | 8 | 360.0 | 175 | 3.15 | 3.440 | 17.02 | 0 | 0 | 3 | 2 | 1.00656597 |
;; | 18.1 | 6 | 225.0 | 105 | 2.76 | 3.460 | 20.22 | 1 | 0 | 3 | 1 | -2.28303904 |
;; | 14.3 | 8 | 360.0 | 245 | 3.21 | 3.570 | 15.84 | 0 | 0 | 3 | 4 | -0.08625625 |
;; | 24.4 | 4 | 146.7 | 62 | 3.69 | 3.190 | 20.00 | 1 | 0 | 4 | 2 | 1.90398812 |
;; | 22.8 | 4 | 140.8 | 95 | 3.92 | 3.150 | 22.90 | 1 | 0 | 4 | 2 | -1.61908990 |
;; | 19.2 | 6 | 167.6 | 123 | 3.92 | 3.440 | 18.30 | 1 | 0 | 4 | 4 | 0.50097006 |
;; | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
;; | 15.5 | 8 | 318.0 | 150 | 2.76 | 3.520 | 16.87 | 0 | 0 | 3 | 2 | -1.44305322 |
;; | 15.2 | 8 | 304.0 | 150 | 3.15 | 3.435 | 17.30 | 0 | 0 | 3 | 2 | -2.53218150 |
;; | 13.3 | 8 | 350.0 | 245 | 3.73 | 3.840 | 15.41 | 0 | 0 | 3 | 4 | -0.00602198 |
;; | 19.2 | 8 | 400.0 | 175 | 3.08 | 3.845 | 17.05 | 0 | 0 | 3 | 2 | 2.50832101 |
;; | 27.3 | 4 | 79.0 | 66 | 4.08 | 1.935 | 18.90 | 1 | 1 | 4 | 1 | -0.99346869 |
;; | 26.0 | 4 | 120.3 | 91 | 4.43 | 2.140 | 16.70 | 0 | 1 | 5 | 2 | -0.15295396 |
;; | 30.4 | 4 | 95.1 | 113 | 3.77 | 1.513 | 16.90 | 1 | 1 | 5 | 2 | 2.76372742 |
;; | 15.8 | 8 | 351.0 | 264 | 4.22 | 3.170 | 14.50 | 0 | 1 | 5 | 4 | -3.07004080 |
;; | 19.7 | 6 | 145.0 | 175 | 3.62 | 2.770 | 15.50 | 0 | 1 | 5 | 6 | 0.00617185 |
;; | 15.0 | 8 | 301.0 | 335 | 3.54 | 3.570 | 14.60 | 0 | 1 | 5 | 8 | 1.05888162 |
;; | 21.4 | 4 | 121.0 | 109 | 4.11 | 2.780 | 18.60 | 1 | 1 | 4 | 2 | -2.96826768 |










;; ----------------------------------------------------------------
;; R
;; m=lm(mpg ~ .,mtcars)
;; > t(glance(m))
;; [,1]
;; r.squared 0.8690157644778
;; adj.r.squared 0.8066423189910
;; sigma 2.6501970278655
;; statistic 13.9324636902088
;; p.value 0.0000003793152
;; df 10.0000000000000
;; logLik -69.8549052172399
;; AIC 163.7098104344797
;; BIC 181.2986412680764
;; deviance 147.4944300166508
;; df.residual 21.0000000000000
;; nobs 32.0000000000000

;; > tidy(m)
;; # A tibble: 11 × 5
;; term estimate std.error statistic p.value
;; <chr> <dbl> <dbl> <dbl> <dbl>
;; 1 (Intercept) 12.3 18.7 0.657 0.518
;; 2 cyl -0.111 1.05 -0.107 0.916
;; 3 disp 0.0133 0.0179 0.747 0.463
;; 4 hp -0.0215 0.0218 -0.987 0.335
;; 5 drat 0.787 1.64 0.481 0.635
;; 6 wt -3.72 1.89 -1.96 0.0633
;; 7 qsec 0.821 0.731 1.12 0.274
;; 8 vs 0.318 2.10 0.151 0.881
;; 9 am 2.52 2.06 1.23 0.234
;; 10 gear 0.655 1.49 0.439 0.665
;; 11 carb -0.199 0.829 -0.241 0.812

;; print(augment(m),width=Inf)
;; # A tibble: 32 × 18
;; .rownames mpg cyl disp hp drat wt qsec vs am gear carb .fitted .resid .hat .sigma .cooksd .std.resid
;; <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
;; 1 Mazda RX4 21 6 160 110 3.9 2.62 16.5 0 1 4 4 22.6 -1.60 0.303 2.68 0.0206 -0.723
;; 2 Mazda RX4 Wag 21 6 160 110 3.9 2.88 17.0 0 1 4 4 22.1 -1.11 0.290 2.70 0.00922 -0.498
;; 3 Datsun 710 22.8 4 108 93 3.85 2.32 18.6 1 1 4 1 26.3 -3.45 0.239 2.57 0.0635 -1.49
;; 4 Hornet 4 Drive 21.4 6 258 110 3.08 3.22 19.4 1 0 3 1 21.2 0.163 0.228 2.72 0.000131 0.0698
;; 5 Hornet Sportabout 18.7 8 360 175 3.15 3.44 17.0 0 0 3 2 17.7 1.01 0.200 2.70 0.00408 0.425
;; 6 Valiant 18.1 6 225 105 2.76 3.46 20.2 1 0 3 1 20.4 -2.28 0.282 2.65 0.0370 -1.02
;; 7 Duster 360 14.3 8 360 245 3.21 3.57 15.8 0 0 3 4 14.4 -0.0863 0.326 2.72 0.0000691 -0.0396
;; 8 Merc 240D 24.4 4 147. 62 3.69 3.19 20 1 0 4 2 22.5 1.90 0.330 2.67 0.0345 0.878
;; 9 Merc 230 22.8 4 141. 95 3.92 3.15 22.9 1 0 4 2 24.4 -1.62 0.742 2.62 0.379 -1.20
;; 10 Merc 280 19.2 6 168. 123 3.92 3.44 18.3 1 0 4 4 18.7 0.501 0.429 2.71 0.00428 0.250
Loading