# Source code for onnx2akida.convert

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

__all__ = ["convert", "print_report"]

import onnx
import onnx_ir as ir
from onnx_ir.passes.common import NameFixPass

from quantizeml.onnx_support.quantization.shape import set_model_shape
from quantizeml.onnx_support.quantization.transforms import sanitize

from .compatibility_info import ModelCompatibilityInfo
from .pipeline import quantize, convert_to_hybrid
from .tools import ensure_model_type, convert_model_to, ONNXExtractorModel
from .hybrid_model import HybridModel


def _match_qnode_on_model(qnode, qmodel, float_model):
    # Search node input names.
    input_names = []
    for node in qmodel.get_parents(qnode):
        if node.op_type == "InputQuantizer":
            # Rename qnode.input to match with a tensor in float_model.
            input_names.append(node.input[0])
        else:
            input_names.append(node.output[0])

    # Search node output names.
    output_name = qnode.output[0]
    if len(children := qmodel.get_children(qnode)) == 1 and children[0].op_type == "Dequantizer":
        # Rename qnode.output to match with model.
        output_name = children[0].output[0]

    # Search all nodes in float_model between inputs and outputs.
    return float_model.extractor._collect_reachable_nodes(input_names, [output_name])


@ensure_model_type
def convert(model, input_shape=None, input_dtype="uint8", samples=None, num_samples=1,
            device=None, enable_hwpr=False, sram_size=None, minimal_memory=False):
    """Check ONNX model compatibility with Akida and convert the model into a HybridModel.

    Args:
        model (onnx.ModelProto): The ONNX model.
        input_shape (Iterable, optional): An iterable specifying the new model input shape
            excluding batch dimension. Defaults to None.
        input_dtype (np.dtype or str, optional): expected model input format. If given as a
            string, should follow numpy string type requirements. Defaults to 'uint8'.
        samples (list of numpy arrays, optional): List of input samples to use for calibration.
            If not provided, random samples will be generated. Defaults to None.
        num_samples (int, optional): Number of samples to use for calibration. Defaults to 1.
        device (akida.Device, optional): the Akida device to map the Akida sub models.
            Defaults to None.
        enable_hwpr (bool, optional): if True, the device is computed assuming partial
            reconfiguration. Used when `device` is None. Defaults to False.
        sram_size (akida.NP.SramSize, optional): Size of shared SRAM available inside the mesh.
            Ignored when `minimal_memory` is True. Used when `device` is None. Defaults to None.
        minimal_memory (bool, optional): if True, computes and sets the minimal required inputs
            and weights memory for the device. Used when `device` is None. Defaults to False.

    Returns:
        HybridModel, ModelCompatibilityInfo: converted model and object containing information
        about model compatibility.

    Raises:
        ValueError: if the provided ONNX model fails the validity check.
    """
    # Check model validity with full_check=False to avoid running shape inference check.
    # It will be done later at the end of set_model_shape.
    try:
        onnx.checker.check_model(model.model, full_check=False)
    except onnx.checker.ValidationError as e:
        # Chain the original ValidationError so the root cause stays visible.
        raise ValueError(f"Invalid ONNX model: {e}") from e

    # Sanitize the model, since _check_akida_compatibility requires it.
    model = set_model_shape(model.clone(), input_shape=input_shape)
    sanitized_model = sanitize(model)

    # Convert to extractor model to handle compatibility info.
    sanitized_model = convert_model_to(sanitized_model, new_type=ONNXExtractorModel)

    # Generates node names to better understand incompatibilities.
    # Use onnx_ir NameFixPass to fix when there are duplicates.
    ir_pass = NameFixPass()
    sanitized_model.model = ir.to_proto(ir_pass(ir.from_proto(sanitized_model.model)).model)

    # Quantize model.
    qmodel, q_compatibility_info = quantize(sanitized_model, input_dtype=input_dtype,
                                            samples=samples, num_samples=num_samples)

    # Convert model.
    hybrid_model, ak_compatibility_info = convert_to_hybrid(qmodel, device=device,
                                                            enable_hwpr=enable_hwpr,
                                                            sram_size=sram_size,
                                                            minimal_memory=minimal_memory)

    # Merge compatibilities into just one.
    compatibility_info = ModelCompatibilityInfo(sanitized_model.model)
    for q_incompatibility in q_compatibility_info.incompatible_sequences:
        compatibility_info._set_incompatibility(node_sequence=q_incompatibility.nodes,
                                                stage=q_incompatibility.stage,
                                                faulty_node=q_incompatibility.faulty_node,
                                                reason=q_incompatibility.reason)
    for c_incompatibility in ak_compatibility_info.incompatible_sequences:
        # Match node sequence in float model before to set in general info.
        node_sequence = [_match_qnode_on_model(qnode, qmodel, sanitized_model)
                         for qnode in c_incompatibility.nodes]
        # Transfer incompatibility to general info (flattening the per-node lists).
        compatibility_info._set_incompatibility(node_sequence=sum(node_sequence, []),
                                                stage=c_incompatibility.stage,
                                                faulty_node=c_incompatibility.faulty_node,
                                                reason=c_incompatibility.reason)
    return hybrid_model, compatibility_info