# Source code for cantools.database.can.database

import logging
from collections import OrderedDict
from typing import (
    Any,
    Dict,
    List,
    Optional,
    TextIO,
    Union,
)

from ...typechecking import DecodeResultType, EncodeInputType, StringPathLike
from ..errors import DecodeError
from ..utils import (
    SORT_SIGNALS_DEFAULT,
    sort_signals_by_start_bit,
    type_sort_attributes,
    type_sort_choices,
    type_sort_signals,
)
from .bus import Bus
from .formats import arxml, dbc, kcd, sym
from .formats.arxml import AutosarDatabaseSpecifics
from .formats.dbc import DbcSpecifics
from .internal_database import InternalDatabase
from .message import Message
from .node import Node

LOGGER = logging.getLogger(__name__)


class Database:
    """This class contains all messages, signals and definitions of a
    CAN network.

    The factory functions :func:`load()<cantools.database.load()>`,
    :func:`load_file()<cantools.database.load_file()>` and
    :func:`load_string()<cantools.database.load_string()>` return
    instances of this class.

    If `strict` is ``True`` an exception is raised if any signals are
    overlapping or if they don't fit in their message.

    By default signals are sorted by their start bit when their Message object is created.
    If you don't want them to be sorted pass `sort_signals = None`.
    If you want the signals to be sorted in another way pass something like
    `sort_signals = lambda signals: list(sorted(signals, key=lambda sig: sig.name))`
    """

    def __init__(self,
                 messages: Optional[List[Message]] = None,
                 nodes: Optional[List[Node]] = None,
                 buses: Optional[List[Bus]] = None,
                 version: Optional[str] = None,
                 dbc_specifics: Optional[DbcSpecifics] = None,
                 autosar_specifics: Optional[AutosarDatabaseSpecifics] = None,
                 frame_id_mask: Optional[int] = None,
                 strict: bool = True,
                 sort_signals: type_sort_signals = sort_signals_by_start_bit,
                 ) -> None:
        # A non-empty list passed in is stored as-is (not copied).
        self._messages = messages or []
        self._nodes = nodes or []
        self._buses = buses or []
        self._name_to_message: Dict[str, Message] = {}
        self._frame_id_to_message: Dict[int, Message] = {}
        self._version = version
        self._dbc = dbc_specifics
        self._autosar = autosar_specifics

        if frame_id_mask is None:
            frame_id_mask = 0xffffffff

        # Applied to frame ids before they are used as lookup keys, both
        # when messages are registered and when they are looked up.
        self._frame_id_mask = frame_id_mask
        self._strict = strict
        self._sort_signals = sort_signals
        self.refresh()

    @property
    def messages(self) -> List[Message]:
        """A list of messages in the database.

        Use :meth:`.get_message_by_frame_id()` or
        :meth:`.get_message_by_name()` to find a message by its frame
        id or name.

        """

        return self._messages

    @property
    def nodes(self) -> List[Node]:
        """A list of nodes in the database.

        """

        return self._nodes

    @property
    def buses(self) -> List[Bus]:
        """A list of CAN buses in the database.

        """

        return self._buses

    @property
    def version(self) -> Optional[str]:
        """The database version, or ``None`` if unavailable.

        """

        return self._version

    @version.setter
    def version(self, value: Optional[str]) -> None:
        self._version = value

    @property
    def dbc(self) -> Optional[DbcSpecifics]:
        """An object containing dbc specific properties like e.g. attributes.

        """

        return self._dbc

    @dbc.setter
    def dbc(self, value: Optional[DbcSpecifics]) -> None:
        self._dbc = value

    @property
    def autosar(self) -> Optional[AutosarDatabaseSpecifics]:
        """An object containing AUTOSAR specific properties like e.g. attributes.

        """

        return self._autosar

    @autosar.setter
    def autosar(self, value: Optional[AutosarDatabaseSpecifics]) -> None:
        self._autosar = value

    def is_similar(self,
                   other: "Database",
                   *,
                   tolerance: float = 1e-12,
                   include_format_specifics: bool = True) -> bool:
        """Compare two database objects inexactly

        This means that small discrepancies stemming from
        e.g. rounding errors are ignored.
        """

        return self._objects_similar(self,
                                     other,
                                     tolerance,
                                     include_format_specifics)

    @staticmethod
    def _objects_similar(a: Any,
                         b: Any,
                         tolerance: float,
                         include_format_specifics: bool) -> bool:
        """Recursively compare `a` and `b`, tolerating small float differences.

        Floats whose magnitude exceeds 1 are compared relatively, all
        other floats absolutely, against `tolerance`. Containers are
        compared element-wise, and instances of other classes are
        compared via their public, non-callable attributes.
        """
        if type(a) != type(b):
            # the types of the objects do not match
            return False
        elif a is None:
            # a and b are None
            return True
        elif isinstance(a, (int, str, bytes, bytearray, set, frozenset)):
            # the values of the objects must be equal. (bytes and
            # frozensets must be handled here: in the generic attribute
            # comparison below all of their public attributes are
            # callable and would be skipped, making any two of them
            # "similar".)
            return bool(a == b)
        elif isinstance(a, float):
            # floating point objects are compared inexactly
            if abs(a) > 1:
                if abs(1.0 - b/a) > tolerance:
                    return False
            else:
                if abs(b - a) > tolerance:
                    return False

            return True
        elif isinstance(a, (list, tuple)):
            # lists and tuples are similar if they have the same length
            # and all elements are similar. Without the length check a
            # longer `b` would be accepted and a shorter one would raise
            # IndexError.
            if len(a) != len(b):
                return False

            return all(Database._objects_similar(a_elem,
                                                 b_elem,
                                                 tolerance,
                                                 include_format_specifics)
                       for a_elem, b_elem in zip(a, b))
        elif isinstance(a, (dict, OrderedDict)):
            # dictionaries are similar if they feature the same keys and
            # all elements are similar
            if a.keys() != b.keys():
                return False

            for key in a:
                if not Database._objects_similar(a[key],
                                                 b[key],
                                                 tolerance,
                                                 include_format_specifics):
                    return False

            return True

        # assume that `a` and `b` are objects of custom classes
        a_attrib_names = dir(a)
        b_attrib_names = dir(b)

        if not include_format_specifics:
            # ignore format specific attributes if requested. So far,
            # only DBC and ARXML amend the database with format
            # specific information.
            for x in 'dbc', 'autosar':
                if x in a_attrib_names:
                    a_attrib_names.remove(x)
                if x in b_attrib_names:
                    b_attrib_names.remove(x)

        # both objects must exhibit the same attributes and member functions
        if a_attrib_names != b_attrib_names:
            return False

        for attrib_name in a_attrib_names:
            if attrib_name.startswith('_'):
                # ignore non-public attributes
                continue

            a_attrib = getattr(a, attrib_name)
            b_attrib = getattr(b, attrib_name)

            if type(a_attrib) != type(b_attrib):
                return False
            elif callable(a_attrib):
                # ignore callable attributes
                continue
            elif not Database._objects_similar(a_attrib,
                                               b_attrib,
                                               tolerance,
                                               include_format_specifics):
                return False

        return True

    def add_arxml(self, fp: TextIO) -> None:
        """Read and parse ARXML data from given file-like object and add
        the parsed data to the database.

        """

        self.add_arxml_string(fp.read())

    def add_arxml_file(self,
                       filename: StringPathLike,
                       encoding: str = 'utf-8') -> None:
        """Open, read and parse ARXML data from given file and add the
        parsed data to the database.

        `encoding` specifies the file encoding.

        """

        with open(filename, encoding=encoding, errors='replace') as fin:
            self.add_arxml(fin)

    def add_arxml_string(self, string: str) -> None:
        """Parse given ARXML data string and add the parsed data to the
        database.

        Parsed messages are appended; nodes, buses, version, DBC and
        AUTOSAR specifics are replaced by those of the parsed data.

        """

        database = arxml.load_string(string, self._strict, sort_signals=self._sort_signals)

        self._messages += database.messages
        self._nodes = database.nodes
        self._buses = database.buses
        self._version = database.version
        self._dbc = database.dbc
        self._autosar = database.autosar
        self.refresh()

    def add_dbc(self, fp: TextIO) -> None:
        """Read and parse DBC data from given file-like object and add the
        parsed data to the database.

        >>> db = cantools.database.Database()
        >>> with open ('foo.dbc', 'r') as fin:
        ...     db.add_dbc(fin)

        """

        self.add_dbc_string(fp.read())

    def add_dbc_file(self,
                     filename: StringPathLike,
                     encoding: str = 'cp1252') -> None:
        """Open, read and parse DBC data from given file and add the
        parsed data to the database.

        `encoding` specifies the file encoding.

        >>> db = cantools.database.Database()
        >>> db.add_dbc_file('foo.dbc')

        """

        with open(filename, encoding=encoding, errors='replace') as fin:
            self.add_dbc(fin)

    def add_dbc_string(self, string: str) -> None:
        """Parse given DBC data string and add the parsed data to the
        database.

        Parsed messages are appended; nodes, buses, version and DBC
        specifics are replaced by those of the parsed data.

        >>> db = cantools.database.Database()
        >>> with open ('foo.dbc', 'r') as fin:
        ...     db.add_dbc_string(fin.read())

        """

        database = dbc.load_string(string, self._strict, sort_signals=self._sort_signals)

        self._messages += database.messages
        self._nodes = database.nodes
        self._buses = database.buses
        self._version = database.version
        self._dbc = database.dbc
        self.refresh()

    def add_kcd(self, fp: TextIO) -> None:
        """Read and parse KCD data from given file-like object and add the
        parsed data to the database.

        """

        self.add_kcd_string(fp.read())

    def add_kcd_file(self,
                     filename: StringPathLike,
                     encoding: str = 'utf-8') -> None:
        """Open, read and parse KCD data from given file and add the
        parsed data to the database.

        `encoding` specifies the file encoding.

        """

        with open(filename, encoding=encoding, errors='replace') as fin:
            self.add_kcd(fin)

    def add_kcd_string(self, string: str) -> None:
        """Parse given KCD data string and add the parsed data to the
        database.

        Parsed messages are appended; nodes, buses, version and DBC
        specifics are replaced by those of the parsed data.

        """

        database = kcd.load_string(string, self._strict, sort_signals=self._sort_signals)

        self._messages += database.messages
        self._nodes = database.nodes
        self._buses = database.buses
        self._version = database.version
        self._dbc = database.dbc
        self.refresh()

    def add_sym(self, fp: TextIO) -> None:
        """Read and parse SYM data from given file-like object and add the
        parsed data to the database.

        """

        self.add_sym_string(fp.read())

    def add_sym_file(self,
                     filename: StringPathLike,
                     encoding: str = 'utf-8') -> None:
        """Open, read and parse SYM data from given file and add the
        parsed data to the database.

        `encoding` specifies the file encoding.

        """

        with open(filename, encoding=encoding, errors='replace') as fin:
            self.add_sym(fin)

    def add_sym_string(self, string: str) -> None:
        """Parse given SYM data string and add the parsed data to the
        database.

        Parsed messages are appended; nodes, buses, version and DBC
        specifics are replaced by those of the parsed data.

        """

        database = sym.load_string(string, self._strict, sort_signals=self._sort_signals)

        self._messages += database.messages
        self._nodes = database.nodes
        self._buses = database.buses
        self._version = database.version
        self._dbc = database.dbc
        self.refresh()

    def _add_message(self, message: Message) -> None:
        """Add given message to the database.

        The message is registered under its name and under its frame id
        masked with the database's frame id mask. An existing entry with
        the same key is overwritten and a warning is logged.

        """

        if message.name in self._name_to_message:
            LOGGER.warning("Overwriting message '%s' with '%s' in the "
                           "name to message dictionary.",
                           self._name_to_message[message.name].name,
                           message.name)

        masked_frame_id = (message.frame_id & self._frame_id_mask)

        if masked_frame_id in self._frame_id_to_message:
            LOGGER.warning(
                "Overwriting message '%s' with '%s' in the frame id to message "
                "dictionary because they have identical masked frame ids 0x%x.",
                self._frame_id_to_message[masked_frame_id].name,
                message.name,
                masked_frame_id)

        self._name_to_message[message.name] = message
        self._frame_id_to_message[masked_frame_id] = message

    def as_dbc_string(self, *,
                      sort_signals:type_sort_signals=SORT_SIGNALS_DEFAULT,
                      sort_attribute_signals:type_sort_signals=SORT_SIGNALS_DEFAULT,
                      sort_attributes:type_sort_attributes=None,
                      sort_choices:type_sort_choices=None,
                      shorten_long_names:bool=True) -> str:
        """Return the database as a string formatted as a DBC file.
           sort_signals defines how to sort signals in message definitions
           sort_attribute_signals defines how to sort signals in metadata -
              comments, value table definitions and attributes

        """
        if not self._sort_signals and sort_signals == SORT_SIGNALS_DEFAULT:
            sort_signals = None

        return dbc.dump_string(InternalDatabase(self._messages,
                                                self._nodes,
                                                self._buses,
                                                self._version,
                                                self._dbc),
                               sort_signals=sort_signals,
                               sort_attribute_signals=sort_attribute_signals,
                               sort_attributes=sort_attributes,
                               sort_choices=sort_choices,
                               shorten_long_names=shorten_long_names)

    def as_kcd_string(self, *,
                      sort_signals:type_sort_signals=SORT_SIGNALS_DEFAULT) -> str:
        """Return the database as a string formatted as a KCD file.

        """
        if not self._sort_signals and sort_signals == SORT_SIGNALS_DEFAULT:
            sort_signals = None

        return kcd.dump_string(InternalDatabase(self._messages,
                                                self._nodes,
                                                self._buses,
                                                self._version,
                                                self._dbc),
                               sort_signals=sort_signals)

    def as_sym_string(self, *,
                      sort_signals:type_sort_signals=SORT_SIGNALS_DEFAULT) -> str:
        """Return the database as a string formatted as a SYM file.

        """
        if not self._sort_signals and sort_signals == SORT_SIGNALS_DEFAULT:
            sort_signals = None

        return sym.dump_string(InternalDatabase(self._messages,
                                                self._nodes,
                                                self._buses,
                                                self._version,
                                                self._dbc),
                               sort_signals=sort_signals)

    def get_message_by_name(self, name: str) -> Message:
        """Find the message object for given name `name`.

        """

        return self._name_to_message[name]

    def get_message_by_frame_id(self, frame_id: int) -> Message:
        """Find the message object for given frame id `frame_id`.

        The frame id is masked with the database's frame id mask before
        the lookup.

        """

        return self._frame_id_to_message[frame_id & self._frame_id_mask]

    def get_node_by_name(self, name: str) -> Node:
        """Find the node object for given name `name`.

        """

        for node in self._nodes:
            if node.name == name:
                return node

        raise KeyError(name)

    def get_bus_by_name(self, name: str) -> Bus:
        """Find the bus object for given name `name`.

        """

        for bus in self._buses:
            if bus.name == name:
                return bus

        raise KeyError(name)

    def _get_message_by_frame_id_or_name(
            self,
            frame_id_or_name: Union[int, str]) -> Message:
        """Find a message by frame id (masked like
        :meth:`get_message_by_frame_id`) or by name.

        Raises ``KeyError`` if no such message exists and ``ValueError``
        if `frame_id_or_name` is neither an ``int`` nor a ``str``.

        """

        if isinstance(frame_id_or_name, int):
            return self.get_message_by_frame_id(frame_id_or_name)
        elif isinstance(frame_id_or_name, str):
            return self.get_message_by_name(frame_id_or_name)

        raise ValueError(f"Invalid frame_id_or_name '{frame_id_or_name}'")

    def encode_message(self,
                       frame_id_or_name: Union[int, str],
                       data: EncodeInputType,
                       scaling: bool = True,
                       padding: bool = False,
                       strict: bool = True,
                       ) -> bytes:
        """Encode given signal data `data` as a message of given frame id or
        name `frame_id_or_name`. For regular Messages, `data` is a
        dictionary of signal name-value entries, for container
        messages it is a list of (ContainedMessageOrMessageName,
        ContainedMessageSignals) tuples.

        A frame id is masked with the database's frame id mask before
        the lookup, just like in :meth:`get_message_by_frame_id`.

        If `scaling` is ``False`` no scaling of signals is performed.

        If `padding` is ``True`` unused bits are encoded as 1.

        If `strict` is ``True`` all signal values must be within their
        allowed ranges, or an exception is raised.

        >>> db.encode_message(158, {'Bar': 1, 'Fum': 5.0})
        b'\\x01\\x45\\x23\\x00\\x11'
        >>> db.encode_message('Foo', {'Bar': 1, 'Fum': 5.0})
        b'\\x01\\x45\\x23\\x00\\x11'

        """

        message = self._get_message_by_frame_id_or_name(frame_id_or_name)

        return message.encode(data, scaling, padding, strict)

    def decode_message(self,
                       frame_id_or_name: Union[int, str],
                       data: bytes,
                       decode_choices: bool = True,
                       scaling: bool = True,
                       decode_containers: bool = False,
                       allow_truncated: bool = False
                       ) \
                       -> DecodeResultType:

        """Decode given signal data `data` as a message of given frame id or
        name `frame_id_or_name`. Returns a dictionary of signal
        name-value entries.

        A frame id is masked with the database's frame id mask before
        the lookup, just like in :meth:`get_message_by_frame_id`.

        If `decode_choices` is ``False`` scaled values are not
        converted to choice strings (if available).

        If `scaling` is ``False`` no scaling of signals is performed.

        >>> db.decode_message(158, b'\\x01\\x45\\x23\\x00\\x11')
        {'Bar': 1, 'Fum': 5.0}
        >>> db.decode_message('Foo', b'\\x01\\x45\\x23\\x00\\x11')
        {'Bar': 1, 'Fum': 5.0}

        If `decode_containers` is ``True``, container frames are
        decoded. The reason why this needs to be explicitly enabled is
        that decoding container frames returns a list of ``(Message,
        SignalsDict)`` tuples which will cause code that does not
        expect this to misbehave. Trying to decode a container message
        with `decode_containers` set to ``False`` will raise a
        `DecodeError`.
        """

        message = self._get_message_by_frame_id_or_name(frame_id_or_name)

        if message.is_container:
            if decode_containers:
                return message.decode(data,
                                      decode_choices,
                                      scaling,
                                      decode_containers=True,
                                      allow_truncated=allow_truncated)
            else:
                raise DecodeError(f'Message "{message.name}" is a container '
                                  f'message, but decoding such messages has '
                                  f'not been enabled!')

        return message.decode(data,
                              decode_choices,
                              scaling,
                              allow_truncated=allow_truncated)

    def refresh(self) -> None:
        """Refresh the internal database state.

        This method must be called after modifying any message in the
        database to refresh the internal lookup tables used when
        encoding and decoding messages.

        """

        self._name_to_message = {}
        self._frame_id_to_message = {}

        for message in self._messages:
            message.refresh(self._strict)
            self._add_message(message)

    def __repr__(self) -> str:
        lines = [f"version('{self._version}')", '']

        if self._nodes:
            for node in self._nodes:
                lines.append(repr(node))

            lines.append('')

        for message in self._messages:
            lines.append(repr(message))

            for signal in message.signals:
                lines.append('  ' + repr(signal))

            lines.append('')

        return '\n'.join(lines)