[go: nahoru, domu]

Skip to content

Commit

Permalink
[shape_poly] Shape polymorphism support for approx_top_k
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 543633818
  • Loading branch information
gnecula authored and tensorflower-gardener committed Jun 27, 2023
1 parent ab87e62 commit 9e74ca4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
_version = 165

# Version number for MLIR:Python components.
-mlir_api_version = 50
+mlir_api_version = 51

xla_platform_names = {
'cpu': 'Host',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1675,6 +1675,7 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
// The attributes supported by the ApproxTopK custom_call are:
//
// - called_computation : This indicates the comparator for scoring entries
+// - has_side_effect: always False
// - api_version : always 4, the typed FFI API
// - backend_config : The actual arguments to ApproxTopK. This includes
// + top_k:i64 : the number of results to return
Expand Down Expand Up @@ -1738,7 +1739,8 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
auto isSupportedAttrName = [](NamedAttribute attr) {
auto name = attr.getName();
return name == "call_target_name" || name == "backend_config" ||
-name == "api_version" || name == "called_computations";
+name == "api_version" || name == "called_computations" ||
+name == "has_side_effect";
};
for (const auto& attr : op->getAttrs()) {
if (!isSupportedAttrName(attr))
Expand Down

0 comments on commit 9e74ca4

Please sign in to comment.