# Copyright 2017 D-Wave Systems Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module has the primary public-facing methods for the project.
"""
from six import iteritems
import penaltymodel.core as pm
from penaltymodel.cache.database_manager import cache_connect, insert_penalty_model, \
iter_penalty_model_from_specification
# Names exported by ``from <module> import *``: the two public entry points
# defined below.
__all__ = ['get_penalty_model',
'cache_penalty_model']
@pm.interface.penaltymodel_factory(100)
def get_penalty_model(specification, database=None):
    """Factory function for penaltymodel_cache.

    Looks up a penalty model matching the given specification in the cache
    database. When the specification's graph is not labelled with the
    integers 0..n-1, the specification is relabelled to integer indices for
    the lookup and the retrieved model is relabelled back to the caller's
    original labels before it is returned.

    Args:
        specification (penaltymodel.Specification): The specification
            for the desired penalty model.
        database (str, optional): The path to the desired sqlite database
            file. If None, will use the default.

    Returns:
        :class:`penaltymodel.PenaltyModel`: Penalty model with the given specification.

    Raises:
        :class:`penaltymodel.MissingPenaltyModel`: If the penalty model is not in the
            cache.

    Parameters:
        priority (int): 100

    """
    # only handles index-labelled nodes, so relabel anything else and remember
    # how to undo the relabelling on the way out
    relabel_applied = not _is_index_labelled(specification.graph)
    if relabel_applied:
        mapping, inverse_mapping = _graph_canonicalization(specification.graph)
        specification = specification.relabel_variables(mapping, inplace=False)

    # connect to the database. Note that once the connection is made it cannot be
    # broken up between several processes.
    if database is None:
        conn = cache_connect()
    else:
        conn = cache_connect(database)

    # get the penalty_model. The connection is closed in `finally` so that it
    # is released even when the lookup itself raises.
    try:
        with conn as cur:
            # take the first match only; None signals that nothing matched
            widget = next(iter_penalty_model_from_specification(cur, specification), None)
    finally:
        conn.close()

    if widget is None:
        raise pm.MissingPenaltyModel("no penalty model with the given specification found in cache")

    if relabel_applied:
        # relabel the widget in-place back to the caller's labels
        widget.relabel_variables(inverse_mapping, inplace=True)

    return widget
def cache_penalty_model(penalty_model, database=None):
    """Caching function for penaltymodel_cache.

    Inserts the given penalty model into the cache database. When the
    model's graph is not labelled with the integers 0..n-1, a relabelled
    copy (integer-indexed) is stored; the caller's object is not modified.

    Args:
        penalty_model (:class:`penaltymodel.PenaltyModel`): Penalty model to
            be cached.
        database (str, optional): The path to the desired sqlite database
            file. If None, will use the default.

    """
    # only handles index-labelled nodes
    if not _is_index_labelled(penalty_model.graph):
        # the inverse mapping is not needed because nothing is handed back
        mapping, __ = _graph_canonicalization(penalty_model.graph)
        penalty_model = penalty_model.relabel_variables(mapping, inplace=False)

    # connect to the database. Note that once the connection is made it cannot be
    # broken up between several processes.
    if database is None:
        conn = cache_connect()
    else:
        conn = cache_connect(database)

    # load into the database. The connection is closed in `finally` so that it
    # is released even when the insert raises.
    try:
        with conn as cur:
            insert_penalty_model(cur, penalty_model)
    finally:
        conn.close()
def _is_index_labelled(graph):
"""graph is index-labels [0, len(graph) - 1]"""
return all(v in graph for v in range(len(graph)))
def _graph_canonicalization(graph):
try:
inverse_mapping = dict(enumerate(sorted(graph)))
except TypeError:
inverse_mapping = dict(enumerate(graph))
mapping = {v: idx for idx, v in iteritems(inverse_mapping)}
return mapping, inverse_mapping