-
Notifications
You must be signed in to change notification settings - Fork 3.6k
/
__init__.py
144 lines (102 loc) · 4.05 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
# Public API of this module: the names exported by `from <this module> import *`.
# `onnx_ml_opset_version` is public and documented like `onnx_opset_version`,
# so it is exported alongside it.
__all__ = [
    "C",
    "ONNX_DOMAIN",
    "ONNX_ML_DOMAIN",
    "AI_ONNX_PREVIEW_TRAINING_DOMAIN",
    "has",
    "register_schema",
    "deregister_schema",
    "get_schema",
    "get_all_schemas",
    "get_all_schemas_with_history",
    "onnx_opset_version",
    "onnx_ml_opset_version",
    "get_function_ops",
    "OpSchema",
    "SchemaError",
]
import onnx.onnx_cpp2py_export.defs as C # noqa: N812
from onnx import AttributeProto, FunctionProto
# Operator-set domain names. The empty string is the default ONNX domain,
# i.e. the one `onnx_opset_version` reports on as `ai.onnx`.
ONNX_DOMAIN = ""
ONNX_ML_DOMAIN = "ai.onnx.ml"
AI_ONNX_PREVIEW_TRAINING_DOMAIN = "ai.onnx.preview.training"
# Re-export schema-registry functions from the extension module `C` under
# public names of this module (listed in `__all__`).
has = C.has_schema
get_schema = C.get_schema
get_all_schemas = C.get_all_schemas
get_all_schemas_with_history = C.get_all_schemas_with_history
deregister_schema = C.deregister_schema
def onnx_opset_version() -> int:
    """Return the current opset version of domain `ai.onnx`.

    The schema version map stores a ``(min_version, max_version)`` pair per
    domain; the upper bound of the pair for ``ONNX_DOMAIN`` is returned.
    """
    default_domain_range = C.schema_version_map()[ONNX_DOMAIN]
    return default_domain_range[1]
def onnx_ml_opset_version() -> int:
    """Return the current opset version of domain `ai.onnx.ml`.

    Looks up the ``(min_version, max_version)`` pair recorded for
    ``ONNX_ML_DOMAIN`` in the schema version map and returns its upper bound.
    """
    ml_domain_range = C.schema_version_map()[ONNX_ML_DOMAIN]
    return ml_domain_range[1]
@property  # type: ignore
def _function_proto(self):  # type: ignore
    """Decode ``self._function_body`` into a newly created ``FunctionProto``.

    A fresh message is built on every access; nothing is cached.
    """
    decoded_body = FunctionProto()
    decoded_body.ParseFromString(self._function_body)
    return decoded_body
# Re-export the extension's schema class under the public name `OpSchema`.
OpSchema = C.OpSchema  # type: ignore
# Attach `_function_proto` as a getter-only property, so `schema.function_body`
# yields a decoded FunctionProto rather than the serialized `_function_body`.
OpSchema.function_body = _function_proto  # type: ignore
@property  # type: ignore
def _attribute_default_value(self):  # type: ignore
    """Decode ``self._default_value`` into a newly created ``AttributeProto``.

    A fresh message is built on every access; nothing is cached.
    """
    decoded_default = AttributeProto()
    decoded_default.ParseFromString(self._default_value)
    return decoded_default
# Attach `_attribute_default_value` as a getter-only property, so
# `attribute.default_value` yields a decoded AttributeProto rather than the
# serialized `_default_value`.
OpSchema.Attribute.default_value = _attribute_default_value  # type: ignore
def _op_schema_repr(self) -> str:
    """Return a multi-line, constructor-style representation of an OpSchema.

    The output starts with ``OpSchema(``. Each listed field follows on its own
    line, indented by four spaces and rendered with ``repr``. Fields are
    separated by commas with no trailing comma, and a closing ``)`` ends the
    output on its own line.
    """
    field_names = (
        "name",
        "domain",
        "since_version",
        "doc",
        "type_constraints",
        "inputs",
        "outputs",
        "attributes",
    )
    # Build the field lines from one list so the layout (indentation, commas)
    # cannot drift out of sync when fields are added or reordered.
    body = ",\n".join(f"    {field}={getattr(self, field)!r}" for field in field_names)
    return f"OpSchema(\n{body}\n)"
# Use `_op_schema_repr` as the repr of every OpSchema instance.
OpSchema.__repr__ = _op_schema_repr  # type: ignore
def _op_schema_formal_parameter_repr(self) -> str:
    """Return a single-line, constructor-style repr of an OpSchema.FormalParameter.

    Every value is rendered with ``repr``. Note that the ``option`` attribute
    is shown under the keyword ``param_option``.
    """
    shown = (
        ("name", self.name),
        ("type_str", self.type_str),
        ("description", self.description),
        ("param_option", self.option),
        ("is_homogeneous", self.is_homogeneous),
        ("min_arity", self.min_arity),
        ("differentiation_category", self.differentiation_category),
    )
    arguments = ", ".join(f"{keyword}={value!r}" for keyword, value in shown)
    return f"OpSchema.FormalParameter({arguments})"
# Use `_op_schema_formal_parameter_repr` as the repr of OpSchema.FormalParameter.
OpSchema.FormalParameter.__repr__ = _op_schema_formal_parameter_repr  # type: ignore
def _op_schema_type_constraint_param_repr(self) -> str:
    """Return a single-line, constructor-style repr of an OpSchema.TypeConstraintParam.

    Shows ``type_param_str``, ``allowed_type_strs`` and ``description``,
    each rendered with ``repr``.
    """
    template = "OpSchema.TypeConstraintParam(type_param_str={!r}, allowed_type_strs={!r}, description={!r})"
    return template.format(self.type_param_str, self.allowed_type_strs, self.description)
# Use `_op_schema_type_constraint_param_repr` as the repr of OpSchema.TypeConstraintParam.
OpSchema.TypeConstraintParam.__repr__ = _op_schema_type_constraint_param_repr  # type: ignore
def _op_schema_attribute_repr(self) -> str:
    """Return a single-line, constructor-style repr of an OpSchema.Attribute.

    Shows name, type, description, default value and the required flag,
    each rendered with ``repr``.
    """
    shown = (
        ("name", self.name),
        ("type", self.type),
        ("description", self.description),
        # On OpSchema.Attribute, `default_value` is the property this module
        # attaches, which decodes the serialized value into an AttributeProto.
        ("default_value", self.default_value),
        ("required", self.required),
    )
    return "OpSchema.Attribute(" + ", ".join(f"{keyword}={value!r}" for keyword, value in shown) + ")"
# Use `_op_schema_attribute_repr` as the repr of OpSchema.Attribute.
OpSchema.Attribute.__repr__ = _op_schema_attribute_repr  # type: ignore
def get_function_ops() -> list[OpSchema]:
    """Return operators defined as functions.

    Selects, from ``C.get_all_schemas()``, every schema whose
    ``has_function`` or ``has_context_dependent_function`` flag is truthy,
    preserving the original order.
    """
    every_schema = C.get_all_schemas()
    return [
        op_schema
        for op_schema in every_schema
        if op_schema.has_function or op_schema.has_context_dependent_function  # type: ignore
    ]
# Re-export the extension module's `SchemaError` under this module's public API.
# NOTE(review): presumably raised by the C-level schema registry on invalid or
# conflicting schemas — confirm against the C++ bindings.
SchemaError = C.SchemaError
def register_schema(schema: OpSchema) -> None:
    """Register a user provided OpSchema.

    Before registering, the ``(min_version, max_version)`` opset range recorded
    for ``schema.domain`` is widened, if necessary, so that it includes
    ``schema.since_version``. A domain with no recorded range is given the
    single-version range ``(since_version, since_version)``.

    Args:
        schema: The OpSchema to register.
    """
    known_ranges = C.schema_version_map()
    domain = schema.domain
    version = schema.since_version
    if domain not in known_ranges:
        # Unknown domain: start a range containing only this version.
        C.set_domain_to_version(domain, version, version)
    else:
        lowest, highest = known_ranges[domain]
        if not lowest <= version <= highest:
            # Stretch whichever end of the range the new version falls beyond.
            C.set_domain_to_version(domain, min(lowest, version), max(highest, version))
    C.register_schema(schema)