Skip to content

Commit

Permalink
DEV: Implement imputation for VCF features (#237)
Browse files Browse the repository at this point in the history
* Update python wrapper to include imputation strategy parameter

* Update scala API to pass imputation strategy to VCFFeatureSource

* Create functions to handle mode and zero imputation strategies

* Added imputation strategy to test cases

* Added imputation strategy to FeatureSource cli

* Remove sparkPar from test cases due to changes in class signature

* Updated DefVariantToFeatureConverterTest to use zeros imputation
  • Loading branch information
NickEdwards7502 committed Oct 17, 2024
1 parent 5ad8cc0 commit b686d75
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 23 deletions.
21 changes: 16 additions & 5 deletions python/varspark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,26 @@ def __init__(self, ss, silent=False):
" /_/ \n"
)

@params(self=object, vcf_file_path=str, min_partitions=int)
def import_vcf(self, vcf_file_path, min_partitions=0):
"""Import features from a VCF file."""
@params(self=object, vcf_file_path=str, imputation_strategy=Nullable(str))
def import_vcf(self, vcf_file_path, imputation_strategy="none"):
"""Import features from a VCF file.
:param vcf_file_path String: The file path for the vcf file to import
:param imputation_strategy String:
The imputation strategy to use. Options for imputation include:
- none: No imputation will be performed. Missing values will be replaced with -1 (not recommended unless there are no missing values)
        - mode: Missing values will be replaced with the most commonly occurring value for that feature. Recommended option
- zeros: Missing values will be replaced with zeros. Faster than mode imputation
"""
if imputation_strategy == "none":
print("WARNING: Imputation strategy is set to none - please ensure that there are no missing values in the data.")
return FeatureSource(
self._jvm,
self._vs_api,
self._jsql,
self.sql,
self._jvsc.importVCF(vcf_file_path, min_partitions),
self._jvsc.importVCF(vcf_file_path, imputation_strategy),
)

@params(
Expand All @@ -76,7 +87,7 @@ def import_vcf(self, vcf_file_path, min_partitions=0):
def import_covariates(self, cov_file_path, cov_types=None, transposed=False):
"""Import covariates from a CSV file.
:param cov_file_path: The file path for covariate csv file
:param cov_file_path String: The file path for covariate csv file
:param cov_types Dict[String]:
A dictionary specifying types for each covariate, where the key is the variable name
and the value is the type. The value can be one of the following:
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/au/csiro/variantspark/api/VSContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ class VSContext(val spark: SparkSession) extends SqlContextHolder {
* @param inputFile path to file or directory with VCF files to load
* @return FeatureSource loaded from the VCF file
*/
def importVCF(inputFile: String, sparkPar: Int = 0): FeatureSource = {
def importVCF(inputFile: String, imputationStrategy: String = "none"): FeatureSource = {
val vcfSource =
VCFSource(sc, inputFile)
// VCFSource(sc.textFile(inputFile, if (sparkPar > 0) sparkPar else sc.defaultParallelism))
VCFFeatureSource(vcfSource)
VCFFeatureSource(vcfSource, imputationStrategy)
}

/** Import features from a CSV file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class CochranArmanCmd extends ArgsApp with SparkApp with Echoable with Logging w
VCFSource(sc.textFile(inputFile, if (sparkPar > 0) sparkPar else sc.defaultParallelism))
verbose(s"VCF Version: ${vcfSource.version}")
verbose(s"VCF Header: ${vcfSource.header}")
VCFFeatureSource(vcfSource)
VCFFeatureSource(vcfSource, imputationStrategy = "none")
}

def loadCSV(): CsvFeatureSource = {
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/au/csiro/variantspark/cli/FilterCmd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class FilterCmd extends ArgsApp with TestArgs with SparkApp {
logDebug(s"Running with filesystem: ${fs}, home: ${fs.getHomeDirectory}")

val vcfSource = VCFSource(sc.textFile(inputFile))
val source = VCFFeatureSource(vcfSource)
val source = VCFFeatureSource(vcfSource, imputationStrategy = "none")
val features = source.features.zipWithIndex().cache()
val featureCount = features.count()
println(s"No features: ${featureCount}")
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/au/csiro/variantspark/cli/VcfToLabels.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class VcfToLabels extends ArgsApp with SparkApp {
val version = vcfSource.version
println(header)
println(version)
val source = VCFFeatureSource(vcfSource)
val source = VCFFeatureSource(vcfSource, imputationStrategy = "none")
val columns = source.features.take(limit)
CSVUtils.withFile(new File(outputFile)) { writer =>
writer.writeRow("" :: columns.map(_.label).toList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ object VCFFeatureSourceFactory {
val DEF_SEPARATOR: String = "_"
}

case class VCFFeatureSourceFactory(inputFile: String, isBiallelic: Option[Boolean],
separator: Option[String])
case class VCFFeatureSourceFactory(inputFile: String, imputationStrategy: Option[String],
isBiallelic: Option[Boolean], separator: Option[String])
extends FeatureSourceFactory with Echoable {
def createSource(sparkArgs: SparkArgs): FeatureSource = {
echo(s"Loading header from VCF file: ${inputFile}")
Expand All @@ -36,8 +36,8 @@ case class VCFFeatureSourceFactory(inputFile: String, isBiallelic: Option[Boolea
verbose(s"VCF Header: ${vcfSource.header}")

import VCFFeatureSourceFactory._
VCFFeatureSource(vcfSource, isBiallelic.getOrElse(DEF_IS_BIALLELIC),
separator.getOrElse(DEF_SEPARATOR))
VCFFeatureSource(vcfSource, imputationStrategy.getOrElse("none"),
isBiallelic.getOrElse(DEF_IS_BIALLELIC), separator.getOrElse(DEF_SEPARATOR))
}
}

Expand Down
35 changes: 30 additions & 5 deletions src/main/scala/au/csiro/variantspark/input/VCFFeatureSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import au.csiro.variantspark.data.StdFeature

trait VariantToFeatureConverter {
def convert(vc: VariantContext): Feature
def convertModeImputed(vc: VariantContext): Feature
def convertZeroImputed(vc: VariantContext): Feature
}

case class DefVariantToFeatureConverter(biallelic: Boolean = false, separator: String = "_")
Expand All @@ -20,6 +22,18 @@ case class DefVariantToFeatureConverter(biallelic: Boolean = false, separator: S
StdFeature.from(convertLabel(vc), BoundedOrdinalVariable(3), gts)
}

def convertModeImputed(vc: VariantContext): Feature = {
val gts = vc.getGenotypes.iterator().asScala.map(convertGenotype).toArray
val modeImputedGts = ModeImputationStrategy(noLevels = 3).impute(gts)
StdFeature.from(convertLabel(vc), BoundedOrdinalVariable(3), modeImputedGts)
}

def convertZeroImputed(vc: VariantContext): Feature = {
val gts = vc.getGenotypes.iterator().asScala.map(convertGenotype).toArray
val zeroImputedGts = ZeroImputationStrategy.impute(gts)
StdFeature.from(convertLabel(vc), BoundedOrdinalVariable(3), zeroImputedGts)
}

def convertLabel(vc: VariantContext): String = {

if (biallelic && !vc.isBiallelic) {
Expand All @@ -44,23 +58,34 @@ case class DefVariantToFeatureConverter(biallelic: Boolean = false, separator: S
}

def convertGenotype(gt: Genotype): Byte = {
if (!gt.isCalled || gt.isHomRef) 0 else if (gt.isHomVar || gt.isHetNonRef) 2 else 1
if (!gt.isCalled) Missing.BYTE_NA_VALUE
else if (gt.isHomRef) 0
else if (gt.isHomVar || gt.isHetNonRef) 2
else 1
}
}

class VCFFeatureSource(vcfSource: VCFSource, converter: VariantToFeatureConverter)
class VCFFeatureSource(vcfSource: VCFSource, converter: VariantToFeatureConverter,
imputationStrategy: String)
extends FeatureSource {
override lazy val sampleNames: List[String] =
vcfSource.header.getGenotypeSamples.asScala.toList
override def features: RDD[Feature] = {
val converterRef = converter
vcfSource.genotypes().map(converterRef.convert)
imputationStrategy match {
case "none" => vcfSource.genotypes().map(converterRef.convert)
case "mode" => vcfSource.genotypes().map(converterRef.convertModeImputed)
case "zeros" => vcfSource.genotypes().map(converterRef.convertZeroImputed)
case _ =>
throw new IllegalArgumentException(s"Unknown imputation strategy: $imputationStrategy")
}
}
}

object VCFFeatureSource {
def apply(vcfSource: VCFSource, biallelic: Boolean = false,
def apply(vcfSource: VCFSource, imputationStrategy: String, biallelic: Boolean = false,
separator: String = "_"): VCFFeatureSource = {
new VCFFeatureSource(vcfSource, DefVariantToFeatureConverter(biallelic, separator))
new VCFFeatureSource(vcfSource, DefVariantToFeatureConverter(biallelic, separator),
imputationStrategy)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ class DefVariantToFeatureConverterTest {
@Test
def testConvertsBialleicVariantCorrctly() {
val converter = DefVariantToFeatureConverter(true, ":")
val result = converter.convert(bialellicVC)
val result = converter.convertZeroImputed(bialellicVC)
assertEquals("chr1:10:T:A", result.label)
assertArrayEquals(expectedEncodedGenotype, result.valueAsByteArray)
}

@Test
def testConvertsMultialleicVariantCorrctly() {
val converter = DefVariantToFeatureConverter(false)
val result = converter.convert(multialleciVC)
val result = converter.convertZeroImputed(multialleciVC)
assertEquals("chr1_10_T_A|G", result.label)
assertArrayEquals(expectedEncodedGenotype, result.valueAsByteArray)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class CovariateReproducibilityTest extends SparkTest {
def testCovariateReproducibleResults() {
implicit val vsContext = VSContext(spark)
implicit val sqlContext = spark.sqlContext
val genotypes = vsContext.importVCF("data/chr22_1000.vcf", 3)
val genotypes = vsContext.importVCF("data/chr22_1000.vcf")
val optVariableTypes = new ArrayList[String](Arrays.asList("CONTINUOUS", "ORDINAL(2)",
"CONTINUOUS", "CONTINUOUS", "CONTINUOUS", "CONTINUOUS"))
val covariates =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class ReproducibilityTest extends SparkTest {
def testReproducibleResults() {
implicit val vsContext = VSContext(spark)
implicit val sqlContext = spark.sqlContext
val features = vsContext.importVCF("data/chr22_1000.vcf", 3)
val features = vsContext.importVCF("data/chr22_1000.vcf")
val label = vsContext.loadLabel("data/chr22-labels.csv", "22_16051249")
val params = RandomForestParams(seed = 13L)
val rfModel1 = RFModelTrainer.trainModel(features, label, params, 40, 20)
Expand Down

0 comments on commit b686d75

Please sign in to comment.