Skip to content

Commit

Permalink
Merge pull request #64 from openclimatefix/issue/get-site-group-details
Browse files Browse the repository at this point in the history
Issue/get site group details
  • Loading branch information
rachel-labri-tipton authored Sep 13, 2023
2 parents d978a8e + edbe834 commit a9b02ec
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 47 deletions.
119 changes: 81 additions & 38 deletions src/sites_toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@
get_all_sites,
get_user_by_email,
get_site_by_uuid,
get_site_group_by_name,
)
from get_data import (
get_all_users,
get_all_site_groups,
get_site_by_client_site_id,
get_site_by_client_site_id,
update_user_site_group,
)


import plotly.graph_objects as go

# get details for one user
# get details for one user
def get_user_details(session, email):
"""Get the user details from the database"""
user_details = get_user_by_email(session=session, email=email)
Expand All @@ -30,39 +29,64 @@ def get_user_details(session, email):
]
return user_sites, user_site_group, user_site_count


# get details for one site
def get_site_details(session, site_uuid):
"""Get the site details for one site"""
site = get_site_by_uuid(session=session, site_uuid=site_uuid)
site_details = {"site_uuid": str(site.site_uuid),
"client_site_id": str(site.client_site_id),
"client_site_name": str(site.client_site_name),
"site_group_names" : [site_group.site_group_name for site_group in site.site_groups],
"latitude": str(site.latitude),
"longitude": str(site.longitude),
"DNO": str(site.dno),
"GSP": str(site.gsp),
"tilt": str(site.tilt),
"orientation": str(site.orientation),
"capacity": (f'{site.capacity_kw} kw'),
"date_added": (site.created_utc.strftime("%Y-%m-%d"))}
return site_details

# user selects site by site_uuid or client_site_id
"""Get the site details for one site"""
site = get_site_by_uuid(session=session, site_uuid=site_uuid)
site_details = {
"site_uuid": str(site.site_uuid),
"client_site_id": str(site.client_site_id),
"client_site_name": str(site.client_site_name),
"site_group_names": [
site_group.site_group_name for site_group in site.site_groups
],
"latitude": str(site.latitude),
"longitude": str(site.longitude),
"DNO": str(site.dno),
"GSP": str(site.gsp),
"tilt": str(site.tilt),
"orientation": str(site.orientation),
"capacity": (f"{site.capacity_kw} kw"),
"date_added": (site.created_utc.strftime("%Y-%m-%d")),
}
return site_details


# select site by site_uuid or client_site_id
def select_site_id(dbsession, query_method):
if query_method == "site_uuid":
site_uuids = [str(site.site_uuid) for site in get_all_sites(session=dbsession)]
selected_uuid = st.selectbox("Sites by site_uuid", site_uuids)
elif query_method == "client_site_id":
client_site_ids= [str(site.client_site_id) for site in get_all_sites(session=dbsession)]
client_site_id= st.selectbox("Sites by client_site_id", client_site_ids)
site = get_site_by_client_site_id(session=dbsession, client_site_id = client_site_id)
selected_uuid = str(site.site_uuid)
elif query_method not in ["site_uuid", "client_site_id"]:
raise ValueError("Please select a valid query_method.")
return selected_uuid
"""Select site by site_uuid or client_site_id"""
if query_method == "site_uuid":
site_uuids = [str(site.site_uuid) for site in get_all_sites(session=dbsession)]
selected_uuid = st.selectbox("Sites by site_uuid", site_uuids)
elif query_method == "client_site_id":
client_site_ids = [
str(site.client_site_id) for site in get_all_sites(session=dbsession)
]
client_site_id = st.selectbox("Sites by client_site_id", client_site_ids)
site = get_site_by_client_site_id(
session=dbsession, client_site_id=client_site_id
)
selected_uuid = str(site.site_uuid)
elif query_method not in ["site_uuid", "client_site_id"]:
raise ValueError("Please select a valid query_method.")
return selected_uuid


# get details for one site group
def get_site_group_details(session, site_group_name):
"""Get the site group details from the database"""
site_group_uuid = get_site_group_by_name(
session=session, site_group_name=site_group_name
)
site_group_sites = [
{"site_uuid": str(site.site_uuid), "client_site_id": str(site.client_site_id)}
for site in site_group_uuid.sites]
site_group_users = [user.email for user in site_group_uuid.users]
return site_group_sites, site_group_users


# sites toolbox page
def sites_toolbox_page():
st.markdown(
f'<h1 style="color:#FFD053;font-size:48px;">{"OCF Dashboard"}</h1>',
Expand All @@ -83,14 +107,15 @@ def sites_toolbox_page():
# get the user details
users = get_all_users(session=session)
user_list = [user.email for user in users]

site_groups = get_all_site_groups(session=session)
site_groups = [site_groups.site_group_name for site_groups in site_groups]

st.markdown(
f'<h1 style="color:#63BCAF;font-size:32px;">{"Get User Details"}</h1>',
unsafe_allow_html=True,
)
email = st.selectbox("Enter email of user you want to know about.", user_list)
# getting user details
# getting user details
if st.button("Get user details"):
user_sites, user_site_group, user_site_count = get_user_details(
session=session, email=email
Expand All @@ -107,20 +132,38 @@ def sites_toolbox_page():
)
if st.button("Close user details"):
st.empty()
# getting site details

# getting site details
st.markdown(
f'<h1 style="color:#63BCAF;font-size:32px;">{"Get Site Details"}</h1>',
unsafe_allow_html=True,
)
query_method = st.radio("Select site by", ("site_uuid", "client_site_id"))

site_id = select_site_id(dbsession=session, query_method=query_method)

if st.button("Get site details"):
site_details = get_site_details(session=session, site_uuid=site_id)
site_id = site_details["client_site_id"] if query_method == "client_site_id" else site_details["site_uuid"]
site_id = (
site_details["client_site_id"]
if query_method == "client_site_id"
else site_details["site_uuid"]
)
st.write("Here are the site details for site", site_id, ":", site_details)
if st.button("Close site details"):
st.empty()

# getting site group details
st.markdown(
f'<h1 style="color:#63BCAF;font-size:32px;">{"Get Site Group Details"}</h1>',
unsafe_allow_html=True,
)
site_group_name = st.selectbox("Enter the site group name.", site_groups)
if st.button("Get site group details"):
site_group_sites, site_group_users = get_site_group_details(session=session, site_group_name=site_group_name)
st.write("Site group", site_group_name, "contains the following", len(site_group_sites), "sites:", site_group_sites)
st.write("The following", len(site_group_users), "users are part of this group:", site_group_users)
if st.button("Close site group details"):
st.empty()


19 changes: 12 additions & 7 deletions tests/test_get_data.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
"""tests for get_data.py"""
from get_data import get_all_users
from get_data import get_all_users, get_all_site_groups
from pvsite_datamodel.read import get_all_sites

#get all users
def test_get_all_users(db_session):
users = get_all_users(session=db_session)
# assert
assert len(users) == 0

# get all site groups
# def test_get_all_sites(db_session):
# sites = get_all_sites(session=db_session)
# # assert
# assert len(sites) == 0

# get all sites
def test_get_all_sites(db_session):
sites = get_all_sites(session=db_session)
# assert
assert len(sites) == 0

# get all site groups
def test_get_all_site_groups(db_session):
site_groups = get_all_site_groups(session=db_session)
# assert
assert len(site_groups) == 0

# update user site group
19 changes: 17 additions & 2 deletions tests/test_sites_toolbox.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test the toolbox functions"""
from sites_toolbox import get_user_details, get_site_details, select_site_id
from sites_toolbox import get_user_details, get_site_details, select_site_id, get_site_group_details
from pvsite_datamodel.write.user_and_site import make_site, make_site_group, make_user

def test_get_user_details(db_session):
Expand Down Expand Up @@ -49,4 +49,19 @@ def test_select_site_id(db_session):
assert site_uuid == str(site.site_uuid)

site_uuid = select_site_id(dbsession=db_session, query_method="client_site_id")
assert site_uuid == str(site.site_uuid)
assert site_uuid == str(site.site_uuid)

# test for get_site_group_details
def test_get_site_group_details(db_session):
"""Test the get site group details function"""
site_group = make_site_group(db_session=db_session)
site_1 = make_site(db_session=db_session, ml_id=1)
site_2 = make_site(db_session=db_session, ml_id=2)
site_group.sites.append(site_1)
site_group.sites.append(site_2)

site_group_sites, site_group_users = get_site_group_details(session=db_session, site_group_name="test_site_group")

assert site_group_sites == [{"site_uuid": str(site.site_uuid), "client_site_id": str(site.client_site_id)}for site in site_group.sites]
assert site_group_users == [user.email for user in site_group.users]

0 comments on commit a9b02ec

Please sign in to comment.