#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

__all__ = ["convert", "print_report"]

import onnx
import onnx_ir as ir
from onnx_ir.passes.common import NameFixPass
from quantizeml.onnx_support.quantization.shape import set_model_shape
from quantizeml.onnx_support.quantization.transforms import sanitize

from .compatibility_info import ModelCompatibilityInfo
from .pipeline import quantize, convert_to_hybrid
from .tools import ensure_model_type, convert_model_to, ONNXExtractorModel
from .hybrid_model import HybridModel


def _match_qnode_on_model(qnode, qmodel, float_model):
    """Return the float_model nodes matching ``qnode``, resolving tensor names across
    InputQuantizer parents and a trailing Dequantizer child."""
# Search node input names.
input_names = []
for node in qmodel.get_parents(qnode):
if node.op_type == "InputQuantizer":
            # Rename qnode.input to match a tensor in float_model.
input_names.append(node.input[0])
else:
input_names.append(node.output[0])
# Search node output names.
output_name = qnode.output[0]
if len(children := qmodel.get_children(qnode)) == 1 and children[0].op_type == "Dequantizer":
        # Rename qnode.output to match a tensor in float_model.
output_name = children[0].output[0]
# Search all nodes in float_model between inputs and outputs.
    return float_model.extractor._collect_reachable_nodes(input_names, [output_name])


@ensure_model_type
def convert(model,
input_shape=None,
input_dtype="uint8",
samples=None,
num_samples=1,
device=None,
enable_hwpr=False,
sram_size=None,
minimal_memory=False):
"""Check ONNX model compatibility with Akida and convert the model into a HybridModel.
Args:
model (onnx.ModelProto): The ONNX model.
input_shape (Iterable, optional): An iterable specifying the new model input shape
excluding batch dimension. Defaults to None.
        input_dtype (np.dtype or str, optional): Expected model input format. If given as a
            string, it should follow numpy string type requirements. Defaults to 'uint8'.
samples (list of numpy arrays, optional): List of input samples to use for
calibration. If not provided, random samples will be generated. Defaults
to None.
num_samples (int, optional): Number of samples to use for calibration.
Defaults to 1.
        device (akida.Device, optional): the Akida device on which to map the Akida
            sub-models. Defaults to None.
enable_hwpr (bool, optional): if True, the device is computed assuming partial
reconfiguration. Used when `device` is None. Defaults to False.
sram_size (akida.NP.SramSize, optional): Size of shared SRAM available inside the mesh.
Ignored when `minimal_memory` is True. Used when `device` is None. Defaults to None.
minimal_memory (bool, optional): if True, computes and sets the minimal required
inputs and weights memory for the device. Used when `device` is None.
Defaults to False.

    Returns:
HybridModel, ModelCompatibilityInfo: converted model and
object containing information about model compatibility.
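
    Example:
        A minimal usage sketch (illustrative only; the model path and input shape below
        are placeholders rather than values required by this API):

        >>> import onnx
        >>> float_model = onnx.load("model.onnx")
        >>> hybrid_model, info = convert(float_model, input_shape=(3, 224, 224))
        >>> print_report(hybrid_model, info)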
"""
    # Check model validity with full_check=False to avoid running the shape inference check;
    # it will be done later, at the end of set_model_shape.
try:
onnx.checker.check_model(model.model, full_check=False)
except onnx.checker.ValidationError as e:
raise ValueError(f"Invalid ONNX model: {e}")
    # Sanitize the model, since _check_akida_compatibility requires it.
model = set_model_shape(model.clone(), input_shape=input_shape)
sanitized_model = sanitize(model)
# Convert to extractor model to handle compatibility info.
sanitized_model = convert_model_to(sanitized_model, new_type=ONNXExtractorModel)
    # Generate node names to better understand incompatibilities, using the onnx_ir
    # NameFixPass to resolve duplicate names.
ir_pass = NameFixPass()
sanitized_model.model = ir.to_proto(ir_pass(ir.from_proto(sanitized_model.model)).model)
# Quantize model.
qmodel, q_compatibility_info = quantize(sanitized_model, input_dtype=input_dtype,
samples=samples, num_samples=num_samples)
# Convert model.
hybrid_model, ak_compatibility_info = convert_to_hybrid(qmodel,
device=device,
enable_hwpr=enable_hwpr,
sram_size=sram_size,
minimal_memory=minimal_memory)
# Merge compatibilities into just one.
compatibility_info = ModelCompatibilityInfo(sanitized_model.model)
for q_incompatibility in q_compatibility_info.incompatible_sequences:
compatibility_info._set_incompatibility(node_sequence=q_incompatibility.nodes,
stage=q_incompatibility.stage,
faulty_node=q_incompatibility.faulty_node,
reason=q_incompatibility.reason)
for c_incompatibility in ak_compatibility_info.incompatible_sequences:
        # Match the node sequence in the float model before setting it in the general info.
node_sequence = [_match_qnode_on_model(qnode, qmodel, sanitized_model)
for qnode in c_incompatibility.nodes]
# Transfer incompatibility to general info.
compatibility_info._set_incompatibility(node_sequence=sum(node_sequence, []),
stage=c_incompatibility.stage,
faulty_node=c_incompatibility.faulty_node,
reason=c_incompatibility.reason)
    return hybrid_model, compatibility_info


def print_report(hybrid_model, compatibility_info):
"""Prints a report of the model compatibility with Akida.
Args:
hybrid_model (HybridModel): The converted hybrid model.
compatibility_info (ModelCompatibilityInfo): The compatibility information.
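
    Example:
        Illustrative sketch; ``hybrid_model`` and ``compatibility_info`` are assumed to
        come from a previous :func:`convert` call:

        >>> hybrid_model, compatibility_info = convert(float_model)
        >>> print_report(hybrid_model, compatibility_info)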
"""
    # Validate argument types for backward compatibility with versions < 0.7.0.
    assert isinstance(hybrid_model, HybridModel) and \
        isinstance(compatibility_info, ModelCompatibilityInfo), "Invalid argument types."
# Color codes
RESET = "\033[0m"
YELLOW = "\033[33m"
CYAN = "\033[36m"
GREEN = "\033[1;32m"
incompatibilities = compatibility_info.incompatibilities
if incompatibilities:
lines = [
f"\nSet of incompatible op_types: {YELLOW}"
f"{compatibility_info.incompatible_op_types}{RESET}",
"List of incompatibilities:",
]
for incompatibility in incompatibilities:
seq_desc = ", ".join(
f"{n['name']}({YELLOW}op_type={n['op_type']}{RESET})"
for n in incompatibility["node_sequence"]
)
lines.append(f" ❌ Node sequence: [{seq_desc}]")
lines.append(f" • {CYAN}Stage{RESET}: {incompatibility['stage']}")
lines.append(f" • {CYAN}Faulty node{RESET}: {incompatibility['faulty_node']}")
lines.append(f" • {CYAN}Reason{RESET}: {incompatibility['reason']}\n")
print("\n".join(lines))
print(
f"[INFO]: Percentage of nodes compatible with akida: "
f"{GREEN}{compatibility_info.compatibility_percentage:.4f} %{RESET}\n"
)
print(
f"[INFO]: Number of mappable sequences on akida: "
f"{GREEN}{len(hybrid_model.akida_models)}{RESET}"
)
if data_movement := hybrid_model.compute_data_movement():
print("\nList of backends exchanges:")
for data in data_movement:
size_kb = data['size'] / 1024
print(f" • {CYAN}{data['type']}{RESET} at layer "
f"{YELLOW}{data['layer'].name}{RESET}: {GREEN}{size_kb:.3f} KB{RESET}")
print()