Skip to content

Commit

Permalink
Implemented color scale and legend for grouped metric plot (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
niekdt committed Jun 11, 2022
1 parent acd1f43 commit e6b24c3
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 13 deletions.
26 changes: 16 additions & 10 deletions R/models.R
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ plotMetric = function(
.loadOptionalPackage('ggplot2')

models = as.lcModels(models)
assert_that(length(models) > 0, msg = 'need at least 1 lcModel to plot')
assert_that(length(models) > 0L, msg = 'need at least 1 lcModel to plot')
assert_that(is.character(name), length(name) >= 1)

if (!missing(subset)) {
Expand All @@ -429,12 +429,15 @@ plotMetric = function(
as.data.table()

assert_that(
nrow(dtModels) == nrow(dtMetrics),
is.null(group) || has_name(dtModels, group)
nrow(dtModels) == nrow(dtMetrics)
)
assert_that(
length(group) == 0L || has_name(dtModels, group),
msg = 'plotMetric() group argument contains names which are not columns of `as.data.frame(models)`'
)

dtModelMetrics = cbind(dtModels, dtMetrics)
if (length(group) == 0) {
if (length(group) == 0L) {
dtModelMetrics[, .group := 'All']
} else {
dtModelMetrics[, .group := do.call(interaction, base::subset(dtModelMetrics, select = group))]
Expand All @@ -452,10 +455,13 @@ plotMetric = function(
setnames('.group', 'Group')
levels(dtgg$Metric) = name

p = ggplot2::ggplot(
data = dtgg,
mapping = ggplot2::aes_string(x = by, y = 'Value', group = 'Group')
)
if (length(group) == 0L) {
map = ggplot2::aes_string(x = by, y = 'Value')
} else {
map = ggplot2::aes_string(x = by, y = 'Value', group = 'Group', color = 'Group')
}

p = ggplot2::ggplot(data = dtgg, mapping = map)

if (is.numeric(dtModelMetrics[[by]]) || is.logical(dtModelMetrics[[by]])) {
p = p + ggplot2::geom_line()
Expand All @@ -464,12 +470,12 @@ plotMetric = function(

if (by == 'nClusters') {
p = p + ggplot2::scale_x_continuous(
breaks = seq(1, max(dtModelMetrics[[by]])),
breaks = seq(1L, max(dtModelMetrics[[by]])),
minor_breaks = NULL
)
}

if (length(name) == 1) {
if (is.scalar(name)) {
p = p + ggplot2::ylab(name)
} else {
p = p + ggplot2::ylab('Value') +
Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/test-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,17 @@ test_that('define metric', {

expect_equal(getInternalMetricDefinition('.NEW'), force)
})

test_that('plot single model', {
expect_is(
plotMetric(testModel, 'WMAE'),
'gg'
)
})

test_that('plot single model, multiple metrics', {
expect_is(
plotMetric(testModel, c('WMAE', 'RMSE')),
'gg'
)
})
9 changes: 6 additions & 3 deletions tests/testthat/test-models.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,16 @@ test_that('plot subset with no results', {
test_that('plotMetric', {
skip_if_not_installed('ggplot2')

plotMetric(models, name='BIC', subset=.method == 'lmkm') %>%
plotMetric(models, name = 'BIC', subset = .method == 'lmkm') %>%
expect_is('gg')

plotMetric(models, name=c('logLik', 'BIC'), subset=.method == 'lmkm') %>%
plotMetric(models, name = c('logLik', 'BIC'), subset = .method == 'lmkm') %>%
expect_is('gg')

plotMetric(models, name=c('logLik', 'BIC'), by='nClusters', group=character()) %>%
plotMetric(models, name = c('logLik', 'BIC'), by = 'nClusters', group = character()) %>%
expect_is('gg')

plotMetric(models, name = c('WMAE', 'RMSE', 'BIC'), by = 'nClusters', group = '.method') %>%
expect_is('gg')
})

Expand Down

0 comments on commit e6b24c3

Please sign in to comment.