Skip to content

Commit

Permalink
fix(frontend): add missing composition serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
aPere3 authored and BourgerieQuentin committed Jun 18, 2024
1 parent 9fffe8b commit 13253d1
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ struct CompositionRule {
};

const std::vector<CompositionRule> DEFAULT_COMPOSITION_RULES = {};
const bool DEFAULT_COMPOSABLE = false;

struct Config {
double p_error;
Expand All @@ -124,24 +125,24 @@ struct Config {
uint32_t ciphertext_modulus_log;
uint32_t fft_precision;
std::vector<CompositionRule> composition_rules;
bool composable;
};

const Config DEFAULT_CONFIG = {
UNSPECIFIED_P_ERROR,
UNSPECIFIED_GLOBAL_P_ERROR,
DEFAULT_DISPLAY,
DEFAULT_STRATEGY,
DEFAULT_KEY_SHARING,
DEFAULT_MULTI_PARAM_STRATEGY,
DEFAULT_SECURITY,
DEFAULT_FALLBACK_LOG_NORM_WOPPBS,
DEFAULT_USE_GPU_CONSTRAINTS,
DEFAULT_ENCODING,
DEFAULT_CACHE_ON_DISK,
DEFAULT_CIPHERTEXT_MODULUS_LOG,
DEFAULT_FFT_PRECISION,
DEFAULT_COMPOSITION_RULES,
};
const Config DEFAULT_CONFIG = {UNSPECIFIED_P_ERROR,
UNSPECIFIED_GLOBAL_P_ERROR,
DEFAULT_DISPLAY,
DEFAULT_STRATEGY,
DEFAULT_KEY_SHARING,
DEFAULT_MULTI_PARAM_STRATEGY,
DEFAULT_SECURITY,
DEFAULT_FALLBACK_LOG_NORM_WOPPBS,
DEFAULT_USE_GPU_CONSTRAINTS,
DEFAULT_ENCODING,
DEFAULT_CACHE_ON_DISK,
DEFAULT_CIPHERTEXT_MODULUS_LOG,
DEFAULT_FFT_PRECISION,
DEFAULT_COMPOSITION_RULES,
DEFAULT_COMPOSABLE};

using Dag = rust::Box<concrete_optimizer::Dag>;
using DagBuilder = rust::Box<concrete_optimizer::DagBuilder>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
options.optimizerConfig.composition_rules.push_back(
{from_func, from_pos, to_func, to_pos});
})
.def("set_composable",
[](CompilationOptions &options, bool composable) {
options.optimizerConfig.composable = composable;
})
.def("set_security_level",
[](CompilationOptions &options, int security_level) {
options.optimizerConfig.security = security_level;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,19 @@ def add_composition(self, from_func: str, from_pos: int, to_func: str, to_pos: i
raise TypeError("expected `to_pos` to be (int)")
self.cpp().add_composition(from_func, from_pos, to_func, to_pos)

def set_composable(self, composable: bool):
"""Set composable flag.
Args:
composable(bool): the composable flag.
Raises:
TypeError: if the inputs do not have the proper type.
"""
if not isinstance(composable, bool):
raise TypeError("expected `composable` to be (bool)")
self.cpp().set_composable(composable)

def set_auto_parallelize(self, auto_parallelize: bool):
"""Set option for auto parallelization.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,11 @@ std::unique_ptr<mlir::Pass> createDagPass(optimizer::Config config,
// Adds the composition rules to the
void applyCompositionRules(optimizer::Config config,
concrete_optimizer::Dag &dag) {

if (config.composable) {
dag.add_all_compositions();
return;
}
for (auto rule : config.composition_rules) {
dag.add_composition(rule.from_func, rule.from_pos, rule.to_func,
rule.to_pos);
Expand Down
2 changes: 1 addition & 1 deletion docs/guides/configure.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t
* **shifts\_with\_promotion**: bool = True,
* Enable promotions in encrypted shifts instead of casting in runtime. See [Bitwise#Shifts](../core-features/bitwise.md#Shifts) to learn more.
* **composable**: bool = False,
* Specify that the function must be composable with itself.
* Specify that the function must be composable with itself. Only used when compiling a single circuit; when compiling modules use the [composition policy](../compilation/composing_functions_with_modules.md#optimizing_runtimes_with_composition_policies).
* **relu\_on\_bits\_threshold**: int = 7,
* Bit-width to start implementing the ReLU extension with [fhe.bits](../core-features/bit\_extraction.md).
* **relu\_on\_bits\_chunk\_size**: int = 3,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,7 @@ class Configuration:
dynamic_indexing_check_out_of_bounds: bool
dynamic_assignment_check_out_of_bounds: bool
simulate_encrypt_run_decrypt: bool
composable: bool

def __init__(
self,
Expand Down
39 changes: 36 additions & 3 deletions frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# pylint: disable=import-error,no-member,no-name-in-module

import json
import shutil
import tempfile
from pathlib import Path
Expand Down Expand Up @@ -35,7 +36,7 @@
from mlir.ir import Module as MlirModule

from ..internal.utils import assert_that
from .composition import CompositionRule
from .composition import CompositionClause, CompositionRule
from .configuration import (
DEFAULT_GLOBAL_P_ERROR,
DEFAULT_P_ERROR,
Expand Down Expand Up @@ -65,6 +66,7 @@ class Server:

_mlir: Optional[str]
_configuration: Optional[Configuration]
_composition_rules: Optional[List[CompositionRule]]

def __init__(
self,
Expand All @@ -74,6 +76,7 @@ def __init__(
compilation_result: LibraryCompilationResult,
server_program: ServerProgram,
is_simulated: bool,
composition_rules: Optional[List[CompositionRule]],
):
self.client_specs = client_specs
self.is_simulated = is_simulated
Expand All @@ -84,6 +87,7 @@ def __init__(
self._compilation_feedback = self._support.load_compilation_feedback(compilation_result)
self._server_program = server_program
self._mlir = None
self._composition_rules = composition_rules

assert_that(
support.load_client_parameters(compilation_result).serialize()
Expand Down Expand Up @@ -131,7 +135,9 @@ def create(
options.set_enable_overflow_detection_in_simulation(
configuration.detect_overflow_in_simulation
)
composition_rules = composition_rules if composition_rules else []

options.set_composable(configuration.composable)
composition_rules = list(composition_rules) if composition_rules else []
for rule in composition_rules:
options.add_composition(rule.from_.func, rule.from_.pos, rule.to.func, rule.to.pos)

Expand Down Expand Up @@ -219,6 +225,7 @@ def create(

client_parameters = support.load_client_parameters(compilation_result)
client_specs = ClientSpecs(client_parameters)
composition_rules = composition_rules if composition_rules else None

result = Server(
client_specs,
Expand All @@ -227,6 +234,7 @@ def create(
compilation_result,
server_program,
is_simulated,
composition_rules,
)

# pylint: disable=protected-access
Expand Down Expand Up @@ -268,6 +276,9 @@ def save(self, path: Union[str, Path], via_mlir: bool = False):
with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f:
f.write(jsonpickle.dumps(self._configuration.__dict__))

with open(Path(tmp) / "composition_rules.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self._composition_rules))

shutil.make_archive(path, "zip", tmp)

return
Expand All @@ -282,6 +293,9 @@ def save(self, path: Union[str, Path], via_mlir: bool = False):
with open(Path(self._output_dir) / "is_simulated", "w", encoding="utf-8") as f:
f.write("1" if self.is_simulated else "0")

with open(Path(self._output_dir) / "composition_rules.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self._composition_rules))

shutil.make_archive(path, "zip", self._output_dir)

@staticmethod
Expand All @@ -308,14 +322,32 @@ def load(path: Union[str, Path]) -> "Server":
with open(output_dir_path / "is_simulated", "r", encoding="utf-8") as f:
is_simulated = f.read() == "1"

composition_rules = None
if (output_dir_path / "composition_rules.json").exists():
with open(output_dir_path / "composition_rules.json", "r", encoding="utf-8") as f:
composition_rules = json.loads(f.read())
composition_rules = (
[
CompositionRule(
CompositionClause(rule[0][0], rule[0][1]),
CompositionClause(rule[1][0], rule[1][1]),
)
for rule in composition_rules
]
if composition_rules
else None
)

if (output_dir_path / "circuit.mlir").exists():
with open(output_dir_path / "circuit.mlir", "r", encoding="utf-8") as f:
mlir = f.read()

with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
configuration = Configuration().fork(**jsonpickle.loads(f.read()))

return Server.create(mlir, configuration, is_simulated)
return Server.create(
mlir, configuration, is_simulated, composition_rules=composition_rules
)

with open(output_dir_path / "client.specs.json", "rb") as f:
client_specs = ClientSpecs.deserialize(f.read())
Expand All @@ -335,6 +367,7 @@ def load(path: Union[str, Path]) -> "Server":
compilation_result,
server_program,
is_simulated,
composition_rules,
)

def run(
Expand Down

0 comments on commit 13253d1

Please sign in to comment.