-
Notifications
You must be signed in to change notification settings - Fork 0
/
ragged.pyi
28 lines (26 loc) · 982 Bytes
/
ragged.pyi
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from typing import Any, Callable, Sequence
import tensorflow as tf
from tensorflow import RaggedTensor, ScalarTensorCompatible, TensorCompatible
# Signature stub for `stack`.
# `values` is a sequence whose items may be dense tensor-compatible values or
# RaggedTensors (mixing is allowed by the element type).
# `axis` is a scalar tensor-compatible value defaulting to 0.
# `name` is an optional op name.
# The result is declared as a RaggedTensor.
# NOTE(review): TF documents a dense 1-D result for rank-0 inputs; the declared
# return type does not model that case, so confirm before relying on it.
def stack(
    values: Sequence[TensorCompatible | RaggedTensor],
    axis: ScalarTensorCompatible = 0,
    name: str | None = None,
) -> RaggedTensor: ...
# Signature stub for `constant`: builds a RaggedTensor from `pylist`.
# Every parameter except `pylist` is optional; `row_splits_dtype` defaults to
# int64.
def constant(
    # Source data; presumably a (possibly ragged) nested Python list — confirm
    # against the TF docs.
    pylist: TensorCompatible,
    # Element dtype; None leaves the choice to TensorFlow.
    dtype: tf.DType | None = None,
    # Optional ragged rank of the result (TF semantics — confirm in docs).
    ragged_rank: int | None = None,
    # Optional shape of the uniform inner dimensions.
    inner_shape: tuple[int, ...] | None = None,
    # Optional op name.
    name: str | None = None,
    # Dtype used for the row-partitioning tensors of the result.
    row_splits_dtype: tf.DType = tf.dtypes.int64,
) -> RaggedTensor: ...
# Signature stub for `map_flat_values`: applies `op` to the flat values of the
# ragged arguments and returns the re-partitioned result.
#
# `*args` / `**kwargs` are forwarded to `op`. Per the TensorFlow docs they may
# mix RaggedTensors with dense tensors or plain Python values, for example
# `map_flat_values(tf.add, rt, 5)`. They are therefore typed `Any`; the
# previous `tf.RaggedTensor` annotation rejected such valid calls.
#
# NOTE(review): TF appears to return `op(*args, **kwargs)` unchanged when no
# argument is a RaggedTensor. The declared return type stays RaggedTensor for
# backward compatibility with existing callers. Confirm before widening it.
def map_flat_values(
    op: Callable[..., tf.Tensor],
    *args: Any,
    **kwargs: Any,
) -> RaggedTensor: ...
# Signature stub for `range`, producing a RaggedTensor.
# The name mirrors the public API on purpose. Inside this module it shadows the
# builtin `range`, so any code added below must not rely on the builtin.
# `starts` is required. `limits` is optional (None by default) and `deltas`
# defaults to 1. `dtype` and `name` are optional, and `row_splits_dtype`
# defaults to int64.
def range(
    starts: TensorCompatible,
    limits: TensorCompatible | None = None,
    deltas: TensorCompatible = 1,
    dtype: tf.DType | None = None,
    name: str | None = None,
    row_splits_dtype: tf.DType = tf.dtypes.int64,
) -> RaggedTensor: ...
# Module-level catch-all for the stub: any attribute not declared explicitly in
# this module is typed as `Any`. This marks the stub as (presumably)
# incomplete rather than exhaustive.
def __getattr__(name: str) -> Any:
    ...