Skip to content

Commit

Permalink
Add multi_switch strategy (#2409)
Browse files Browse the repository at this point in the history
* feat: add multi_switch strategy

* chore: add tests and config flow

* chore: commit WIP

* chore: commit WIP

* chore: commit WIP

* chore: commit WIP

* chore: commit WIP

* chore: commit WIP

* chore: commit WIP

* chore: commit WIP

* chore: commit WIP

* chore: commit WIP

* chore: commit WIP

* feat: add profile
  • Loading branch information
bramstroker authored Jul 26, 2024
1 parent cb1ac0c commit d815509
Show file tree
Hide file tree
Showing 17 changed files with 462 additions and 104 deletions.
73 changes: 64 additions & 9 deletions custom_components/powercalc/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from homeassistant.const import (
CONF_ATTRIBUTE,
CONF_DEVICE,
CONF_ENTITIES,
CONF_ENTITY_ID,
CONF_NAME,
CONF_UNIQUE_ID,
Expand Down Expand Up @@ -54,12 +55,14 @@
CONF_MIN_POWER,
CONF_MODE,
CONF_MODEL,
CONF_MULTI_SWITCH,
CONF_MULTIPLY_FACTOR,
CONF_MULTIPLY_FACTOR_STANDBY,
CONF_ON_TIME,
CONF_PLAYBOOK,
CONF_PLAYBOOKS,
CONF_POWER,
CONF_POWER_OFF,
CONF_POWER_TEMPLATE,
CONF_REPEAT,
CONF_SELF_USAGE_INCLUDED,
Expand Down Expand Up @@ -88,6 +91,7 @@
)
from .discovery import get_power_profile_by_source_entity
from .errors import ModelNotSupportedError, StrategyConfigurationError
from .helpers import get_or_create_unique_id
from .power_profile.factory import get_power_profile
from .power_profile.library import ModelInfo, ProfileLibrary
from .power_profile.power_profile import DOMAIN_DEVICE_TYPE, DeviceType, PowerProfile
Expand Down Expand Up @@ -117,6 +121,7 @@ class Steps(StrEnum):
VIRTUAL_POWER = "virtual_power"
FIXED = "fixed"
LINEAR = "linear"
MULTI_SWITCH = "multi_switch"
PLAYBOOK = "playbook"
WLED = "wled"
POWER_ADVANCED = "power_advanced"
Expand Down Expand Up @@ -233,6 +238,7 @@ class Steps(StrEnum):
options=[
CalculationStrategy.FIXED,
CalculationStrategy.LINEAR,
CalculationStrategy.MULTI_SWITCH,
CalculationStrategy.PLAYBOOK,
CalculationStrategy.WLED,
CalculationStrategy.LUT,
Expand Down Expand Up @@ -265,6 +271,16 @@ class Steps(StrEnum):
},
)

SCHEMA_POWER_MULTI_SWITCH = vol.Schema(
{
vol.Required(CONF_ENTITIES): selector.EntitySelector(
selector.EntitySelectorConfig(domain=Platform.SWITCH, multiple=True),
),
vol.Required(CONF_POWER): vol.Coerce(float),
vol.Required(CONF_POWER_OFF): vol.Coerce(float),
},
)

SCHEMA_POWER_PLAYBOOK = vol.Schema(
{
vol.Optional(CONF_PLAYBOOKS): selector.ObjectSelector(),
Expand Down Expand Up @@ -398,6 +414,8 @@ def create_strategy_schema(self, strategy: str, source_entity_id: str) -> vol.Sc
return self.create_schema_linear(source_entity_id)
if strategy == CalculationStrategy.PLAYBOOK:
return SCHEMA_POWER_PLAYBOOK
if strategy == CalculationStrategy.MULTI_SWITCH:
return self.create_schema_multi_switch()
if strategy == CalculationStrategy.WLED:
return SCHEMA_POWER_WLED
return vol.Schema({})
Expand Down Expand Up @@ -480,6 +498,16 @@ def create_schema_linear(source_entity_id: str) -> vol.Schema:
},
)

def create_schema_multi_switch(self) -> vol.Schema:
"""Create the config schema for multi switch strategy."""
schema = SCHEMA_POWER_MULTI_SWITCH
# Remove power options if we are in library flow as they are defined in the power profile
if self.is_library_flow:
del schema.schema[CONF_POWER]
del schema.schema[CONF_POWER_OFF]
schema = vol.Schema(schema.schema)
return schema

def create_schema_virtual_power(
self,
) -> vol.Schema:
Expand Down Expand Up @@ -705,7 +733,8 @@ async def async_step_integration_discovery(

self.source_entity_id = self.source_entity.entity_id
self.name = self.source_entity.name
unique_id = f"pc_{self.source_entity.unique_id}"

unique_id = get_or_create_unique_id(self.sensor_config, self.source_entity, self.power_profile)
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()

Expand Down Expand Up @@ -765,18 +794,15 @@ async def async_step_virtual_power(
self.source_entity_id,
self.hass,
)
unique_id = user_input.get(CONF_UNIQUE_ID)
if not unique_id and self.source_entity_id != DUMMY_ENTITY_ID:
source_unique_id = self.source_entity.unique_id or self.source_entity_id
unique_id = f"pc_{source_unique_id}"

await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()

self.name = user_input.get(CONF_NAME) or self.source_entity.name
self.selected_sensor_type = SensorType.VIRTUAL_POWER
self.sensor_config.update(user_input)

unique_id = get_or_create_unique_id(self.sensor_config, self.source_entity, self.power_profile)
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()

return await self.forward_to_strategy_step(selected_strategy)

return self.async_show_form(
Expand All @@ -797,6 +823,9 @@ async def forward_to_strategy_step(
if strategy == CalculationStrategy.LINEAR:
return await self.async_step_linear()

if strategy == CalculationStrategy.MULTI_SWITCH:
return await self.async_step_multi_switch()

if strategy == CalculationStrategy.PLAYBOOK:
return await self.async_step_playbook()

Expand Down Expand Up @@ -927,6 +956,25 @@ async def async_step_linear(
last_step=False,
)

async def async_step_multi_switch(
self,
user_input: dict[str, Any] | None = None,
) -> ConfigFlowResult:
"""Handle the flow for multi switch strategy."""
errors = {}
if user_input is not None:
self.sensor_config.update({CONF_MULTI_SWITCH: user_input})
errors = await self.validate_strategy_config()
if not errors:
return await self.async_step_power_advanced()

return self.async_show_form(
step_id=Steps.MULTI_SWITCH,
data_schema=self.create_schema_multi_switch(),
errors=errors,
last_step=False,
)

async def async_step_playbook(
self,
user_input: dict[str, Any] | None = None,
Expand Down Expand Up @@ -1096,9 +1144,16 @@ async def async_step_post_library(
if self.power_profile and self.power_profile.needs_fixed_config:
return await self.async_step_fixed()

if self.power_profile and self.power_profile.device_type == DeviceType.SMART_SWITCH:
if (
self.power_profile
and self.power_profile.device_type == DeviceType.SMART_SWITCH
and self.power_profile.calculation_strategy == CalculationStrategy.FIXED
):
return await self.async_step_smart_switch()

if self.power_profile and self.power_profile.calculation_strategy == CalculationStrategy.MULTI_SWITCH:
return await self.async_step_multi_switch()

return await self.async_step_power_advanced()

async def async_step_sub_profile(
Expand Down
1 change: 1 addition & 0 deletions custom_components/powercalc/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
CONF_POWER_SENSOR_FRIENDLY_NAMING = "power_sensor_friendly_naming"
CONF_POWER_SENSOR_PRECISION = "power_sensor_precision"
CONF_POWER = "power"
CONF_POWER_OFF = "power_off"
CONF_POWER_SENSOR_ID = "power_sensor_id"
CONF_POWER_TEMPLATE = "power_template"
CONF_PLAYBOOK = "playbook"
Expand Down
22 changes: 16 additions & 6 deletions custom_components/powercalc/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CalculationStrategy,
)
from .errors import ModelNotSupportedError
from .helpers import get_or_create_unique_id
from .power_profile.factory import get_power_profile
from .power_profile.library import ModelInfo
from .power_profile.power_profile import DOMAIN_DEVICE_TYPE, PowerProfile
Expand Down Expand Up @@ -57,9 +58,16 @@ def __init__(self, hass: HomeAssistant, ha_config: ConfigType) -> None:
self.ha_config = ha_config
self.power_profiles: dict[str, PowerProfile | None] = {}
self.manually_configured_entities: list[str] | None = None
self.initialized_flows: set[str] = set()

async def start_discovery(self) -> None:
"""Start the discovery procedure."""

existing_entries = self.hass.config_entries.async_entries(DOMAIN)
for entry in existing_entries:
if entry.unique_id:
self.initialized_flows.add(entry.unique_id)

_LOGGER.debug("Start auto discovering entities")
entity_registry = er.async_get(self.hass)
for entity_entry in list(entity_registry.entities.values()):
Expand Down Expand Up @@ -221,12 +229,13 @@ def _init_entity_discovery(
extra_discovery_data: dict | None,
) -> None:
"""Dispatch the discovery flow for a given entity."""
existing_entries = [
entry
for entry in self.hass.config_entries.async_entries(DOMAIN)
if entry.unique_id in [source_entity.unique_id, f"pc_{source_entity.unique_id}"]
]
if existing_entries:

unique_id = get_or_create_unique_id({}, source_entity, power_profile)
unique_ids_to_check = [unique_id]
if unique_id.startswith("pc_"):
unique_ids_to_check.append(unique_id[3:])

if any(unique_id in self.initialized_flows for unique_id in unique_ids_to_check):
_LOGGER.debug(
"%s: Already setup with discovery, skipping new discovery",
source_entity.entity_id,
Expand All @@ -246,6 +255,7 @@ def _init_entity_discovery(
if extra_discovery_data:
discovery_data.update(extra_discovery_data)

self.initialized_flows.add(unique_id)
discovery_flow.async_create_flow(
self.hass,
DOMAIN,
Expand Down
26 changes: 26 additions & 0 deletions custom_components/powercalc/helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import decimal
import logging
import os.path
import uuid
from decimal import Decimal

from homeassistant.const import CONF_UNIQUE_ID
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType

from custom_components.powercalc.common import SourceEntity
from custom_components.powercalc.const import DUMMY_ENTITY_ID, CalculationStrategy
from custom_components.powercalc.power_profile.power_profile import PowerProfile

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -35,3 +42,22 @@ def get_library_path(sub_path: str = "") -> str:
def get_library_json_path() -> str:
"""Get the path to the library.json file."""
return get_library_path("library.json")


def get_or_create_unique_id(sensor_config: ConfigType, source_entity: SourceEntity, power_profile: PowerProfile | None) -> str:
"""Get or create the unique id."""
unique_id = sensor_config.get(CONF_UNIQUE_ID)
if unique_id:
return str(unique_id)

# For multi-switch strategy we need to use the device id as unique id
# As we don't want to start a discovery for each switch entity
if power_profile and power_profile.calculation_strategy == CalculationStrategy.MULTI_SWITCH and source_entity.device_entry:
return f"pc_{source_entity.device_entry.id}"

if source_entity and source_entity.entity_id != DUMMY_ENTITY_ID:
source_unique_id = source_entity.unique_id or source_entity.entity_id
# Prefix with pc_ to avoid conflicts with other integrations
return f"pc_{source_unique_id}"

return str(uuid.uuid4())
30 changes: 21 additions & 9 deletions custom_components/powercalc/strategy/factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from decimal import Decimal

from homeassistant.const import CONF_CONDITION, CONF_ENTITIES
from homeassistant.core import HomeAssistant
from homeassistant.helpers import condition
Expand All @@ -14,6 +16,7 @@
CONF_MULTI_SWITCH,
CONF_PLAYBOOK,
CONF_POWER,
CONF_POWER_OFF,
CONF_POWER_TEMPLATE,
CONF_STANDBY_POWER,
CONF_STATES_POWER,
Expand Down Expand Up @@ -186,17 +189,26 @@ async def _create_sub_strategy(strategy_config: ConfigType) -> SubStrategy:

def _create_multi_switch(self, config: ConfigType, power_profile: PowerProfile | None) -> MultiSwitchStrategy:
"""Create instance of multi switch strategy."""
multi_switch_config = config.get(CONF_MULTI_SWITCH)
if multi_switch_config is None:
if power_profile and power_profile.get_strategy_config(CalculationStrategy.MULTI_SWITCH):
multi_switch_config = power_profile.get_strategy_config(CalculationStrategy.MULTI_SWITCH)
multi_switch_config: ConfigType = {}
if power_profile and power_profile.multi_switch_mode_config:
multi_switch_config = power_profile.multi_switch_mode_config
multi_switch_config.update(config.get(CONF_MULTI_SWITCH, {}))

if not multi_switch_config:
raise StrategyConfigurationError("No multi_switch configuration supplied")

entities: list[str] = multi_switch_config.get(CONF_ENTITIES, [])
if not entities:
raise StrategyConfigurationError("No switch entities supplied")

if multi_switch_config is None:
raise StrategyConfigurationError("No multi_switch configuration supplied")
on_power: Decimal | None = multi_switch_config.get(CONF_POWER)
off_power: Decimal | None = multi_switch_config.get(CONF_POWER_OFF)
if off_power is None or on_power is None:
raise StrategyConfigurationError("No power configuration supplied")

return MultiSwitchStrategy(
self._hass,
multi_switch_config.get(CONF_ENTITIES), # type: ignore
on_power=multi_switch_config.get(CONF_POWER), # type: ignore
off_power=config.get(CONF_STANDBY_POWER), # type: ignore
entities,
on_power=Decimal(on_power),
off_power=Decimal(off_power),
)
23 changes: 17 additions & 6 deletions custom_components/powercalc/strategy/multi_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@
import homeassistant.helpers.config_validation as cv
import voluptuous as vol
from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN
from homeassistant.const import CONF_ENTITIES, STATE_ON
from homeassistant.const import CONF_ENTITIES, STATE_ON, STATE_UNAVAILABLE
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.event import TrackTemplate

from custom_components.powercalc.const import CONF_POWER
from custom_components.powercalc.const import CONF_POWER, CONF_POWER_OFF, DUMMY_ENTITY_ID

from .strategy_interface import PowerCalculationStrategyInterface

CONFIG_SCHEMA = vol.Schema(
{
vol.Optional(CONF_POWER): vol.Any(vol.Coerce(float), cv.template),
vol.Optional(CONF_POWER): vol.Coerce(float),
vol.Optional(CONF_POWER_OFF): vol.Coerce(float),
vol.Required(CONF_ENTITIES): cv.entities_domain(SWITCH_DOMAIN),
},
)
Expand All @@ -40,11 +41,21 @@ def __init__(

async def calculate(self, entity_state: State) -> Decimal | None:
if self.known_states is None:
self.known_states = {entity_id: self.hass.states.get(entity_id) for entity_id in self.switch_entities}
self.known_states = {
entity_id: (state.state if (state := self.hass.states.get(entity_id)) else STATE_UNAVAILABLE) for entity_id in self.switch_entities
}

self.known_states[entity_state.entity_id] = entity_state.state
if entity_state.entity_id != DUMMY_ENTITY_ID:
self.known_states[entity_state.entity_id] = entity_state.state

return Decimal(sum(self.on_power if state == STATE_ON else self.off_power for state in self.known_states.values()))
def _get_power(state: str) -> Decimal:
if state == STATE_UNAVAILABLE:
return Decimal(0)
if state == STATE_ON:
return self.on_power
return self.off_power

return Decimal(sum(_get_power(state) for state in self.known_states.values()))

def get_entities_to_track(self) -> list[str | TrackTemplate]:
return self.switch_entities # type: ignore
Expand Down
Loading

0 comments on commit d815509

Please sign in to comment.