Skip to content

Commit

Permalink
Implemented distance metric generation function. Added Mahalanobis di…
Browse files Browse the repository at this point in the history
…stance metric (#14)
  • Loading branch information
niekdt committed Nov 4, 2021
1 parent 5b1abfe commit e2d291b
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 5 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ importFrom(stats,fitted)
importFrom(stats,formula)
importFrom(stats,getCall)
importFrom(stats,logLik)
importFrom(stats,mahalanobis)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,nobs)
Expand Down
139 changes: 139 additions & 0 deletions R/metricsInternal.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,121 @@ getInternalMetricDefinition = function(name) {
}
}

# Define an internal metric that is computed from the distance between the
# trajectories assigned to a cluster and the respective cluster trajectory.
#
# name: base name of the metric.
# type: 'traj' to use the output of trajectories(), or 'fitted' to use the
#   output of fittedTrajectories(). For any type other than 'traj', the metric
#   is registered as "<name>.<type>".
# distanceFun: function(trajClusMat, clusVec, clusName) returning the distance
#   value for a single cluster, where trajClusMat is the matrix of the
#   trajectories of the cluster (one row per trajectory), clusVec is the
#   cluster trajectory, and clusName is the name of the cluster.
# clusterAggregationFun: function(x, w) that combines the per-cluster distances
#   x into a single value, using the cluster proportions as w.
# assertNonEmpty: whether clusters without any trajectories are excluded from
#   the computation, with a warning.
# assertNonSolitary: whether clusters with only one trajectory are excluded
#   from the computation, with a warning.
# assertNonIdentical: whether clusters in which all trajectories are identical
#   are excluded from the computation, with a warning.
# ...: arguments passed to defineInternalMetric().
#
# Returns the result of defineInternalMetric(). The registered metric function
# evaluates to NA when none of the clusters are valid.
.defineInternalDistanceMetric = function(
  name,
  type = c('traj', 'fitted'),
  distanceFun,
  clusterAggregationFun = weighted.mean,
  assertNonEmpty = TRUE,
  assertNonSolitary = FALSE,
  assertNonIdentical = FALSE,
  ...
) {
  type = match.arg(type[1], c('traj', 'fitted'))
  if (type != 'traj') {
    fullName = paste(name, type, sep = '.')
  } else {
    # traj is the default
    fullName = name
  }

  assert_that(
    is.function(distanceFun),
    is.function(clusterAggregationFun)
  )

  trajFun = switch(type,
    traj = trajectories,
    fitted = fittedTrajectories
  )

  fun = function(m) {
    nTimes = length(time(m))
    dfTraj = trajFun(m)
    trajMat = dcastRepeatedMeasures(
      dfTraj,
      id = idVariable(m),
      time = timeVariable(m),
      response = responseVariable(m)
    )
    assert_that(ncol(trajMat) == nTimes)
    # split() handles the matrix as a plain vector and recycles the assignments
    # across the columns, so each group is reshaped back into nTimes columns.
    # NOTE(review): assumes trajectoryAssignments(m) is ordered like the rows
    # of trajMat -- confirm against dcastRepeatedMeasures().
    trajMatList = lapply(split(trajMat, trajectoryAssignments(m)), matrix, ncol = nTimes)

    dtClus = clusterTrajectories(m, at = time(m))
    clusMat = dcastRepeatedMeasures(
      dtClus,
      id = 'Cluster',
      time = timeVariable(m),
      response = responseVariable(m)
    )
    clusVecList = split(clusMat, row(clusMat))
    assert_that(all(lengths(clusVecList) == nTimes))

    clusSizes = vapply(trajMatList, nrow, FUN.VALUE = 0)

    # A cluster is only excluded for a condition when the respective assertion
    # is enabled.
    emptyMask = clusSizes == 0 & assertNonEmpty
    solitaryMask = clusSizes == 1 & assertNonSolitary
    identicalMask = vapply(
      trajMatList,
      function(x) {
        # guard against 0-row matrices, for which x[1, ] is out of bounds.
        # isTRUE() treats a comparison involving NA as "not identical".
        nrow(x) > 0 && isTRUE(all(x[1, ] == t(x)))
      },
      FUN.VALUE = FALSE
    ) & assertNonIdentical

    # warn about the clusters that are excluded for the given reason
    warnExcluded = function(mask, reason) {
      if (any(mask)) {
        warning(
          sprintf(
            'Cannot compute distance metric "%s" for cluster(s) "%s": %s',
            fullName,
            paste0(clusterNames(m)[mask], collapse = '", "'),
            reason
          )
        )
      }
    }

    warnExcluded(emptyMask, 'No trajectories assigned to the cluster.')
    warnExcluded(solitaryMask, 'Only 1 trajectory assigned to the cluster.')
    warnExcluded(
      identicalMask,
      'All trajectories are identical (i.e., zero covariance).'
    )

    validMask = !emptyMask & !solitaryMask & !identicalMask

    if (!any(validMask)) {
      return(as.numeric(NA))
    }

    clusDistances = Map(
      distanceFun,
      trajMatList[validMask],
      clusVecList[validMask],
      clusterNames(m)[validMask]
    )
    clusterAggregationFun(unlist(clusDistances), w = clusterProportions(m)[validMask])
  }

  defineInternalMetric(fullName, fun = fun, ...)
}

#' @title Define the distance metrics for multiple types at once
#' @description Variant of `.defineInternalDistanceMetric()` that is vectorized
#' over the `type` argument. The results are not simplified, so a list is
#' returned.
#' @keywords internal
.defineInternalDistanceMetrics = Vectorize(
  .defineInternalDistanceMetric,
  vectorize.args = 'type',
  SIMPLIFY = FALSE,
  USE.NAMES = TRUE
)


# Internal metric definitions ####
#' @importFrom stats AIC
intMetricsEnv$AIC = AIC
Expand Down Expand Up @@ -94,6 +209,30 @@ intMetricsEnv$MAE = function(m) {
residuals(m) %>% abs %>% mean
}

# . Mahalanobis distance ####
# Registers the "Mahalanobis" metric. Per cluster, the value is the mean
# Mahalanobis distance of the cluster's trajectories to the cluster trajectory,
# using the covariance matrix of the cluster's trajectories. The cluster values
# are combined through weighted.mean().
#' @importFrom stats mahalanobis
.defineInternalDistanceMetric(
  name = 'Mahalanobis',
  type = 'traj',
  distanceFun = function(trajClusMat, clusVec, clusName) {
    vcovMat = cov(trajClusMat)
    # mahalanobis() inverts the covariance matrix through solve(), which fails
    # when the reciprocal condition number is below its default tolerance of
    # .Machine$double.eps. An exact det() == 0 comparison misses matrices that
    # are only numerically singular, so the same criterion as solve() is used.
    if (rcond(vcovMat) < .Machine$double.eps) {
      warning(
        sprintf(
          'Cannot compute Mahalanobis distance for cluster "%s": covariance matrix is singular',
          clusName
        )
      )
      as.numeric(NA)
    } else {
      mean(mahalanobis(trajClusMat, center = clusVec, cov = vcovMat))
    }
  },
  clusterAggregationFun = weighted.mean,
  assertNonSolitary = TRUE,
  assertNonIdentical = TRUE
)

# Mean squared error: the mean of the squared residuals of the model
intMetricsEnv$MSE = function(m) {
  squaredResiduals = residuals(m) ^ 2
  mean(squaredResiduals)
}
Expand Down
21 changes: 21 additions & 0 deletions man/dot-defineInternalDistanceMetrics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions tests/testthat/setup-models.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ kml2 = latrend(lcMethodTestKML(), data = testLongData, nClusters = 2)
kml3 = latrend(lcMethodTestKML(), data = testLongData, nClusters = 3)
kml4 = latrend(lcMethodTestKML(), data = testLongData, nClusters = 4)

gmm1 = latrend(lcMethodTestLcmmGMM(), testLongData, nClusters = 1)
gmm2 = gmm = latrend(lcMethodTestLcmmGMM(), testLongData, nClusters = 2)
gmm3 = latrend(lcMethodTestLcmmGMM(), testLongData, nClusters = 3)
14 changes: 9 additions & 5 deletions tests/testthat/test-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ rngReset()
internalMetrics = getInternalMetricNames() %>% setdiff('deviance')

test_that('two clusters', {
for(name in internalMetrics) {
value = metric(kml2, name=name)
for (name in internalMetrics) {
value = metric(gmm2, name = name)
expect_is(value, 'numeric')
expect_true(is.finite(value), info=name)
expect_true(is.finite(value), info = name)
}
})

test_that('single cluster', {
for(name in internalMetrics) {
value = metric(kml1, name=name)
for (name in internalMetrics) {
value = metric(gmm1, name = name)
expect_is(value, 'numeric')
expect_length(value, 1)
}
Expand Down Expand Up @@ -59,6 +59,10 @@ test_that('WMAE', {
expect_gte(wmaeFuzzy, maeFuzzy)
})

test_that('Mahalanobis', {
  # the Mahalanobis metric should be among the registered internal metrics
  metricNames = getInternalMetricNames()
  expect_true(is.element('Mahalanobis', metricNames))
})

test_that('missing metric', {
expect_warning(met <- metric(kml2, '.MISSING'))
expect_true(is.na(met))
Expand Down

0 comments on commit e2d291b

Please sign in to comment.