Skip to content

Commit

Permalink
support post processing initial step
Browse files Browse the repository at this point in the history
  • Loading branch information
salmma committed May 23, 2024
1 parent 28a7bd2 commit 786bd31
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 9 deletions.
38 changes: 35 additions & 3 deletions app/implementation_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,19 @@ def prepare_code_from_data(data, input_params):
return circuit


def prepare_code_from_url(url, input_params, bearer_token: str = ""):
def prepare_code_from_url(url, input_params, bearer_token: str = "", post_processing=False):
"""Get implementation code from URL. Set input parameters into implementation. Return circuit."""
try:
impl = _download_code(url, bearer_token)
except (error.HTTPError, error.URLError):
return None

circuit = prepare_code_from_data(impl, input_params)
return circuit
if not post_processing:
circuit = prepare_code_from_data(impl, input_params)
return circuit
else:
result = prepare_post_processing_code_from_data(impl, input_params)
return result


def prepare_code_from_qasm(qasm):
Expand All @@ -86,6 +90,34 @@ def prepare_code_from_qasm_url(url, bearer_token: str = ""):
return prepare_code_from_qasm(impl)


def prepare_post_processing_code_from_data(data, input_params):
"""Get implementation code from data. Set input parameters into implementation. Return circuit."""
temp_dir = tempfile.mkdtemp()
with open(os.path.join(temp_dir, "__init__.py"), "w") as f:
f.write("")
with open(os.path.join(temp_dir, "downloaded_code.py"), "w") as f:
f.write(data)
sys.path.append(temp_dir)
try:
import downloaded_code

# deletes every attribute from downloaded_code, except __name__, because importlib.reload
# doesn't reset the module's global variables
for attr in dir(downloaded_code):
if attr != "__name__":
delattr(downloaded_code, attr)

reload(downloaded_code)
if 'post_processing' in dir(downloaded_code):
result = downloaded_code.post_processing(**input_params)
finally:
sys.path.remove(temp_dir)
shutil.rmtree(temp_dir, ignore_errors=True)
if not result:
raise ValueError
return result


def _download_code(url: str, bearer_token: str = "") -> str:
req = request.Request(url)

Expand Down
31 changes: 31 additions & 0 deletions app/post_processing_result_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# ******************************************************************************
# Copyright (c) 2024 University of Stuttgart
#
# See the NOTICE file(s) distributed with this work for additional
# information regarding copyright ownership.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

from app import db


class Post_Processing_Result(db.Model):
id = db.Column(db.String(36), primary_key=True)
execution_result_id = db.Column(db.String(36), db.ForeignKey('result.id'))
generated_circuit_id = db.Column(db.String(36), db.ForeignKey('generated_circuit.id'))
result = db.Column(db.String(1200), default="")
complete = db.Column(db.Boolean, default=False)

def __repr__(self):
return 'Post_Processing_Result {}'.format(self.result)
32 changes: 26 additions & 6 deletions app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import base64
import datetime
import json
import uuid

from qiskit import transpile, QuantumCircuit, Aer
from qiskit.transpiler.exceptions import TranspilerError
Expand All @@ -30,6 +31,7 @@
from app.NumpyEncoder import NumpyEncoder
from app.benchmark_model import Benchmark
from app.generated_circuit_model import Generated_Circuit
from app.post_processing_result_model import Post_Processing_Result
from app.result_model import Result


Expand Down Expand Up @@ -69,14 +71,14 @@ def generate(impl_url, impl_data, impl_language, input_params, bearer_token):
generated_circuit_code)

generated_circuit_object.input_params = json.dumps(input_params)
print(generated_circuit_object.input_params)
app.logger.info(f"Received input params for circuit generation: {generated_circuit_object.input_params}")
generated_circuit_object.complete = True
db.session.commit()


def execute(provider, impl_url, impl_data, impl_language, transpiled_qasm, input_params, token, access_key_aws,
secret_access_key_aws, qpu_name, optimization_level, noise_model, only_measurement_errors, shots,
bearer_token, qasm_string, **kwargs):
def execute(correlation_id, provider, impl_url, impl_data, impl_language, transpiled_qasm, input_params, token,
access_key_aws, secret_access_key_aws, qpu_name, optimization_level, noise_model, only_measurement_errors,
shots, bearer_token, qasm_string, **kwargs):
"""Create database entry for result. Get implementation code, prepare it, and execute it. Save result in db"""
app.logger.info("Starting execute task...")
job = get_current_job()
Expand All @@ -100,14 +102,14 @@ def execute(provider, impl_url, impl_data, impl_language, transpiled_qasm, input
else:
if qasm_string:
circuits = [implementation_handler.prepare_code_from_qasm(qasm) for qasm in qasm_string]
elif impl_url:
elif impl_url and not correlation_id:
if impl_language.lower() == 'openqasm':
# list of circuits
circuits = [implementation_handler.prepare_code_from_qasm_url(url, bearer_token) for url in impl_url]
else:
circuits = [implementation_handler.prepare_code_from_url(url, input_params, bearer_token) for url in
impl_url]
elif impl_data:
elif impl_data and not correlation_id:
impl_data = [base64.b64decode(data.encode()).decode() for data in impl_data]
if impl_language.lower() == 'openqasm':
circuits = [implementation_handler.prepare_code_from_qasm(data) for data in impl_data]
Expand Down Expand Up @@ -168,6 +170,24 @@ def execute(provider, impl_url, impl_data, impl_language, transpiled_qasm, input
result.result = json.dumps(job_result['counts'])
result.complete = True
db.session.commit()

# implementation contains post processing of execution results that has to be executed
if correlation_id and (impl_url or impl_data):
# TODO create new post processing result object
post_processing_result = Post_Processing_Result(id=str(uuid.uuid4()))
# prepare input data containing execution results and initial input params for generating the circuit
generated_circuit = Generated_Circuit.query.get(correlation_id)
input_params_for_post_processing = generated_circuit.input_params
input_params_for_post_processing['counts'] = json.dumps(job_result['counts'])

if impl_url:
result = implementation_handler.prepare_code_from_url(impl_url,
input_params=input_params_for_post_processing,
bearer_token=bearer_token, post_processing=True)
elif impl_data:
result = implementation_handler.prepare_post_processing_code_from_data(data=impl_data,
input_params=input_params_for_post_processing) # TODO save result in postprocessing result object

else:
result = Result.query.get(job.get_id())
result.result = json.dumps({'error': 'execution failed'})
Expand Down

0 comments on commit 786bd31

Please sign in to comment.