Numerical instability of gradient calculation of tf.norm (nan at 0, inf for small values) #12071

Open
oduerr opened this issue Aug 7, 2017 · 29 comments
Labels
stat:contribution welcome Status - Contributions welcome

Comments

@oduerr
oduerr commented Aug 7, 2017

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes see below
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Mac OS X 10.11.6
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): #v1.2.0-5-g435cdfc 1.2.1
  • Python version: 3.6
  • Bazel version (if compiling from source):
  • CUDA/cuDNN version: On CPU
  • GPU model and memory:
  • Exact command to reproduce: tf.norm at [0,0] see below for code
import numpy as np
import tensorflow as tf
print(tf.GIT_VERSION, "  ", tf.VERSION) #v1.2.0-5-g435cdfc    1.2.1

X = tf.placeholder(tf.float32, shape=(4,None))
Z = tf.norm(X, ord='euclidean', axis=1, name='logit')
var_grad = tf.gradients(Z, [X])

with tf.Session() as sess:
    X_ = np.array([
        [1],  # Grad OK
        [0],  # Grad NaN
        [1e-16],  # Grad OK
        [1e-19] #Grad Inf
    ], dtype=np.float32)
    sess.run(tf.global_variables_initializer())
    print(sess.run((Z, var_grad), feed_dict={X: X_}))
    # Result:
    #(array([9.99999940e-01, 0.00000000e+00, 9.99999951e-17,
    #        0.00000000e+00], dtype=float32), [array([[1.00000012],
    #                                                 [nan],
    #                                                 [1.],
    #                                                 [inf]], dtype=float32)])

Describe the problem

The gradient of tf.norm evaluates to nan when the input is exactly zero, and to inf for extremely small values (1e-19 above). Note that the exact result should be 1 in all cases above.

Above is a minimal example to reproduce it. The problem occurred in a real-world scenario, when implementing a custom loss function (the entropy in https://arxiv.org/abs/1611.01449), when two embeddings were too close to each other (distance practically 0).
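Why the threshold sits between 1e-16 and 1e-19 can be checked directly in float32. The NumPy sketch below (not TensorFlow code) squares each value the way the Euclidean norm does: 1e-16 squared is still a normal float32, 1e-19 squared (about 1e-38) is already below the smallest normal float32, and 1e-23 squared underflows all the way to zero. The forward output above reports 0 for the 1e-19 row, which suggests the subnormal square was flushed to zero in that run; that is an inference, not something the report states.

```python
import numpy as np

# Smallest positive *normal* float32 (about 1.18e-38). Smaller magnitudes are
# subnormal, and anything below roughly 7e-46 rounds to zero.
TINY = np.finfo(np.float32).tiny

for v in [1e-16, 1e-19, 1e-23]:
    x = np.float32(v)
    sq = x * x  # the squared term inside the Euclidean norm
    print(f"x={v:g}  x*x={sq:g}  normal={bool(sq >= TINY)}  zero={bool(sq == 0)}")
```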

Source code / logs

See above

Output of logfile

== cat /etc/issue ===============================================
Darwin Olivers-MBP-5.fritz.box 15.6.0 Darwin Kernel Version 15.6.0: Tue Apr 11 16:00:51 PDT 2017; root:xnu-3248.60.11.5.3~1/RELEASE_X86_64 x86_64
Mac OS X 10.11.6

== are we in docker =============================================
No

== compiler =====================================================
Apple LLVM version 7.3.0 (clang-703.0.31)
Target: x86_64-apple-darwin15.6.0
Thread model: posix
InstalledDir: /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin

== uname -a =====================================================
Darwin Olivers-MBP-5.fritz.box 15.6.0 Darwin Kernel Version 15.6.0: Tue Apr 11 16:00:51 PDT 2017; root:xnu-3248.60.11.5.3~1/RELEASE_X86_64 x86_64

== check pips ===================================================
numpy (1.13.0)
protobuf (3.3.0)
tensorflow (1.2.1)

== check for virtualenv =========================================
True

== tensorflow import ============================================
tf.VERSION = 1.2.1
tf.GIT_VERSION = v1.2.0-5-g435cdfc
tf.COMPILER_VERSION = v1.2.0-5-g435cdfc
Sanity check: array([1], dtype=int32)

== env ==========================================================
LD_LIBRARY_PATH is unset
DYLD_LIBRARY_PATH is unset

== nvidia-smi ===================================================
tf_env_collect.sh.txt: line 105: nvidia-smi: command not found

== cuda libs  ===================================================
@oduerr oduerr changed the title Numerical instability in tf.norm (nan at 0, inf for small values) Numerical instability of gradient calculation of tf.norm (nan at 0, inf for small values) Aug 7, 2017
@yaroslavvb
Contributor
yaroslavvb commented Aug 7, 2017

This is caused by the square root in the definition of tf.norm, i.e. you are taking the gradient of sqrt(x^2). The gradient of sqrt approaches infinity while the gradient of x^2 approaches 0, so computing them separately and then multiplying them is a problem. @goodfeli, do you remember if Theano uses some standard stabilizing transformation for this kind of case?
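The mechanism can be reproduced without TensorFlow. The NumPy sketch below is an illustration, not TensorFlow's actual gradient kernels: it evaluates the two chain-rule factors of sqrt(sum(x**2)) separately in float32 and multiplies them, the way a graph of separate ops does. At x = 0 the product is inf * 0 = nan; when x*x underflows to 0 for a tiny nonzero x, the product is inf. NumPy keeps subnormals, so a smaller value than the 1e-19 of the report is needed to trigger the underflow here.

```python
import numpy as np


def naive_norm_grad(x):
    """Chain rule for ||x|| = sqrt(sum(x**2)), evaluated factor by factor in float32."""
    x = np.asarray(x, dtype=np.float32)
    s = np.sum(x * x)                      # inner op: squared sum
    d_sqrt = np.float32(0.5) / np.sqrt(s)  # outer factor d sqrt(s)/ds: inf when s == 0
    d_sq = np.float32(2.0) * x             # inner factor d(x**2)/dx: 0 when x == 0
    return d_sqrt * d_sq                   # inf * 0 -> nan, inf * tiny -> inf


with np.errstate(divide="ignore", invalid="ignore"):
    print(naive_norm_grad([1.0]))    # [1.]
    print(naive_norm_grad([0.0]))    # [nan]
    print(naive_norm_grad([1e-23]))  # [inf]
```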

@goodfeli
goodfeli commented Aug 7, 2017

Theano gives NaN as the gradient of the norm of a vector with zero norm:

>>> x = theano.tensor.vector()
>>> y = theano.tensor.square(x)
>>> z = y.sum()
>>> norm = theano.tensor.sqrt(z)
>>> d = theano.tensor.grad(norm, x)
>>> d.eval({x: [0., 0.]})
array([ nan,  nan])

Theano does give 0 as the gradient of the norm of a scalar. I think for this it is probably using a patternsub to turn sqrt(square(x)) into abs(x). Note that in Theano's conventions, the derivative of abs(x) at 0 is treated as 0 rather than undefined.
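For comparison, a TF 2.x sketch of the scalar case (assumes eager execution; this is not a rewrite TensorFlow performs automatically): differentiating the composed sqrt(square(x)) gives nan at 0, whereas TensorFlow's gradient for tf.abs is sign(x), which, matching the Theano convention described above, is 0 at 0.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

x = tf.constant([0.0, 2.0, -3.0])

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    composed = tf.sqrt(tf.square(x))  # what the norm of a scalar reduces to
    rewritten = tf.abs(x)             # the algebraic rewrite described above

g_composed = tape.gradient(composed, x)    # nan at x == 0 (inf * 0)
g_rewritten = tape.gradient(rewritten, x)  # sign(x): 0 at x == 0
del tape
print(g_composed.numpy(), g_rewritten.numpy())
```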

@yaroslavvb
Contributor
yaroslavvb commented Aug 7, 2017

Thinking philosophically, the general problem is that the computational graph ends up with things like a/a where a is 0. Numerically it's undefined, but the limit exists. A similar issue exists with the gradient of tf.select (#2540) and the gradient of tf.exp(-tf.exp(x)).

You have to do some algebraic massaging to get numerically defined result.
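One common form of this massaging, not proposed in this thread, is the "double tf.where" pattern that is also the usual workaround for the tf.select/tf.where gradient NaNs mentioned above. A minimal sketch, assuming TF 2.x eager mode; norm_double_where is a hypothetical helper name. It keeps the square root away from 0, so all-zero rows get a forward value of exactly 0 and a gradient of 0. Rows whose squares underflow to 0 are treated as zero rows as well.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def norm_double_where(x, axis=-1):
    """Hypothetical helper: Euclidean norm whose gradient is 0 (not nan) at x == 0."""
    sq = tf.reduce_sum(tf.square(x), axis=axis)
    nonzero = sq > 0
    # Inner where: feed sqrt a harmless value (1.0) where sq == 0, so its local
    # gradient there is finite instead of inf.
    safe_sq = tf.where(nonzero, sq, tf.ones_like(sq))
    # Outer where: return exactly 0 for all-zero rows; the zero upstream
    # gradient then multiplies a finite number, giving 0 instead of nan.
    return tf.where(nonzero, tf.sqrt(safe_sq), tf.zeros_like(sq))


x = tf.constant([[0.0, 0.0], [3.0, 4.0]])
with tf.GradientTape() as tape:
    tape.watch(x)
    z = norm_double_where(x)
g = tape.gradient(z, x)
print(z.numpy(), g.numpy())  # norms [0, 5]; gradient rows [0, 0] and [0.6, 0.8]
```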

In your particular case, you could replace automatic gradient with a stable version:

from tensorflow.python.framework import function
import numpy as np
import tensorflow as tf

@function.Defun(tf.float32, tf.float32)
def norm_grad(x, dy):
    return dy*(x/tf.norm(x))

@function.Defun(tf.float32, grad_func=norm_grad)
def norm(x):
    return tf.norm(x)

sess = tf.InteractiveSession()
X = tf.placeholder(tf.float32, shape=(4,None))
X_ = np.array([
    [1],  # Grad OK
    [0],  # Grad NaN
    [1e-16],  # Grad OK
    [1e-19] #Grad Inf
], dtype=np.float32)
Z = norm(X)
var_grad = tf.gradients(Z, [X])
print(sess.run((Z, var_grad), feed_dict={X: X_}))

#(1.0, [array([[  1.00000000e+00],
#              [  0.00000000e+00],
#              [  1.00000002e-16],
#              [  9.99999968e-20]], dtype=float32)])


@oduerr
Author
oduerr commented Aug 8, 2017

First, @yaroslavvb, thanks a lot for the code for the stable version of the gradient (I did not know that it's possible to override the gradient calculation).

I agree that the corner case 0 is not really defined, but the same is true for the gradient of a ReLU at 0. For float32, an exact 0 is not a singular event: you might get it through rounding errors, and then the whole optimization is broken. Furthermore (and even more problematic), there are instabilities in the gradient of tf.norm for small values, e.g. the inf for 1e-19, which you could encounter in an optimization process containing the norm. In my case, part of the loss function was exp(-norm(x-y)) (it took quite some time to nail down the problem to the gradient of tf.norm()). Therefore, I think this is a real bug and not just a pure mathematical problem, and it deserves fixing.

I did not include this in my original report, but the problem is also present for non-scalars. See

X = tf.placeholder(tf.float32, shape=(1,3))
Z = tf.norm(X, ord='euclidean', axis=1, name='logit')
pik = tf.nn.softmax(logits=Z)
res = tf.reduce_sum(pik)
var_grad = tf.gradients(res, [X])

with tf.Session() as sess:
    X_ = np.array([
        [1e-19, 1e-19, 0]
    ], dtype=np.float32)

    sess.run(tf.global_variables_initializer())
    print(sess.run((res, var_grad), feed_dict={X:X_}))

which gives (nan, nan, nan).

@yaroslavvb
Contributor
yaroslavvb commented Aug 8, 2017

Yes, I agree that this should be fixed; the question is how. The bug with NaNs caused by the tf.select gradient, linked above, is still open a year later, and it's a similar problem.

A Python-specific solution would be to replace norm with a fused version like the one above, with a corresponding gradient. Fused versions are not really preferred but are done if it's important enough, as was done for tf.fused_batch_norm.

A more general solution is to have a numerically stabilizing transformation which simplifies expressions that lead to "0/0" or "infinity*0" so that they behave more like their limits. @benoitsteiner @petewarden, do you know if there's anything on numerically stabilizing graph transformations on the roadmap?

@tatatodd tatatodd added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Aug 9, 2017
@olegklimov

+1

Already several times I've been hitting that.

@pandaczm
pandaczm commented Nov 6, 2017

@yaroslavvb Hi, I don't understand why dy*(x/tf.norm(x)) leads to a stable gradient. Could you please explain it in more detail? Thank you.
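The reason, written out: the chain rule splits the derivative of the norm into two factors that are individually singular at zero, while their product simplifies to a ratio that is finite whenever the norm itself is nonzero. With s the squared sum:

```latex
\frac{\partial \lVert x \rVert_2}{\partial x_i}
  = \underbrace{\frac{1}{2\sqrt{s}}}_{\to\,\infty}\cdot\underbrace{2x_i}_{\to\,0}
  = \frac{x_i}{\lVert x \rVert_2},
\qquad s = \sum_j x_j^2,\quad \lVert x \rVert_2 = \sqrt{s}.
```

Computing x / ||x|| directly never forms the inf * 0 product, so zero or tiny individual components are harmless as long as the computed ||x|| is nonzero. In the earlier snippet, tf.norm(x) is taken over the whole 4×1 input, whose norm is about 1, which is why every row came out finite. The ratio is still 0/0 when the entire vector is zero.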

@shashankneocfc
shashankneocfc commented Nov 10, 2017

+1 I'm also facing this issue while computing the Euclidean norm for the paper: A Structured Self-attentive Sentence Embedding

Edit: PyTorch applied a fix for it. Maybe TF can do the same: https://github.com/pytorch/pytorch/pull/2775/files

@HectorAnadon

Any update on this?

@steven200796

Bump for my honors thesis XD

@jart
Contributor
jart commented Apr 26, 2018

Assigning gradient issue to @girving

@jart jart assigned girving and unassigned jart Apr 26, 2018
@girving girving removed their assignment Apr 26, 2018
@yaroslavvb
Contributor

I posted a stable implementation above which you can substitute for the TF default version. I don't know if this works for an external contributor; that depends on whether there are internal tests which may be broken by changing the numerics of tf.norm.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Apr 27, 2018
@jart jart assigned yaroslavvb and unassigned jart May 11, 2018
@yaroslavvb
Contributor

Marking as contributions welcome. I gave an example solution above; someone wanting to be a part of TF could work the solution into TensorFlow.

@yaroslavvb yaroslavvb added the stat:contribution welcome Status - Contributions welcome label May 11, 2018
@yaroslavvb yaroslavvb removed their assignment May 11, 2018
@gsutanto

Hi @yaroslavvb ,
Here's the problem:

import numpy as np
import tensorflow as tf

X = tf.placeholder(tf.float32, shape=(4,None))
Z = tf.norm(X)
var_grad = tf.gradients(Z, [X])

with tf.Session() as sess:
    X_ = np.array([
        [0],
        [0],
        [0],
        [0]
    ], dtype=np.float32)
    
    [Z_val, Z_grad] = sess.run([Z, var_grad], feed_dict={X: X_})
    
    print("Z_val =", Z_val)
    print("Z_grad =", Z_grad)

It will give output:

Z_val = 0.0
Z_grad = [array([[nan],
[nan],
[nan],
[nan]], dtype=float32)]

Using your suggestion:

from tensorflow.python.framework import function
import numpy as np
import tensorflow as tf

@function.Defun(tf.float32, tf.float32)
def norm_grad(x, dy):
    return dy*(x/(tf.norm(x, ord=2)))

@function.Defun(tf.float32, grad_func=norm_grad)
def norm(x):
    return tf.norm(x, ord=2)

X = tf.placeholder(tf.float32, shape=(4,None))
Z = norm(X)
var_grad = tf.gradients(Z, [X])

with tf.Session() as sess:
    X_ = np.array([
        [0],
        [0],
        [0],
        [0]
    ], dtype=np.float32)
    
    [Z_val, Z_grad] = sess.run([Z, var_grad], feed_dict={X: X_})
    
    print("Z_val =", Z_val)
    print("Z_grad =", Z_grad)

I still get:

Z_val = 0.0
Z_grad = [array([[nan],
[nan],
[nan],
[nan]], dtype=float32)]

So here's my simple modification that seems to work (just by adding a small epsilon to the gradient's denominator):

from tensorflow.python.framework import function
import numpy as np
import tensorflow as tf

@function.Defun(tf.float32, tf.float32)
def norm_grad(x, dy):
    return dy*(x/(tf.norm(x, ord=2)+1.0e-19))

@function.Defun(tf.float32, grad_func=norm_grad)
def norm(x):
    return tf.norm(x, ord=2)

X = tf.placeholder(tf.float32, shape=(4,None))
Z = norm(X)
var_grad = tf.gradients(Z, [X])

with tf.Session() as sess:
    X_ = np.array([
        [0],
        [0],
        [0],
        [0]
    ], dtype=np.float32)
    
    [Z_val, Z_grad] = sess.run([Z, var_grad], feed_dict={X: X_})
    
    print("Z_val =", Z_val)
    print("Z_grad =", Z_grad)

This gives:

Z_val = 0.0
Z_grad = [array([[0.],
[0.],
[0.],
[0.]], dtype=float32)]
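A NumPy sketch of the two gradient formulas above (float32, an illustration rather than the TensorFlow kernels; fused_norm_grad is a hypothetical name) shows why the epsilon matters: without it, an all-zero input gives 0/0 = nan. It also shows a limitation: if the squares of tiny entries underflow, the computed norm is 0 and the epsilon-guarded result is finite but far from the true derivative of 1.

```python
import numpy as np


def fused_norm_grad(x, eps=0.0):
    """d||x||/dx = x / ||x||, optionally with an additive epsilon, all in float32."""
    x = np.asarray(x, dtype=np.float32)
    n = np.sqrt(np.sum(x * x))  # forward norm, computed in float32
    return x / (n + np.float32(eps))


with np.errstate(divide="ignore", invalid="ignore"):
    print(fused_norm_grad([0, 0, 0, 0]))                 # 0/0 -> nan in every entry
    print(fused_norm_grad([0, 0, 0, 0], eps=1e-19))      # epsilon-guarded -> zeros
    print(fused_norm_grad([1e-23, 0, 0, 0], eps=1e-19))  # finite (~1e-4), but not 1
```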

@gokul-uf

@yaroslavvb Thanks for the snippet! Unfortunately, with eager execution enabled,
I get the following error:
AttributeError: Tensor.op is meaningless when eager execution is enabled.

The following works with eager execution enabled

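@tf.custom_gradient
def norm(x, axis=None, keepdims=False):
    # y (and dy) must broadcast against x, so use axis=None or keepdims=True
    y = tf.norm(x, axis=axis, keepdims=keepdims)

    def grad(dy):
        # the small epsilon keeps the denominator non-zero when the norm is 0
        return dy * (x / (y + 1e-19))

    return y, grad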

@aayux
aayux commented Sep 19, 2018

Thanks, @gsutanto for your solution. I wanted to ask what the correct way to force calculate gradients is when there is a complicated interaction of variables, such as when we're using gradient descent to optimize a loss function term with the norm as penalty.

Do I have to force the gradient of the entire loss function or should only the gradient of norm be computed separately?

Drawing from your code, what I'm saying is something like in the case below:

...
X = tf.placeholder(tf.float32, shape=(4, 1))
Y = tf.placeholder(tf.float32, shape=(1, 1))
W = tf.Variable(tf.zeros(shape=[1, 4]))

Z =  norm(X)
var_grad = tf.gradients(Z, [X])

J = tf.reduce_mean((Y - tf.matmul(W, X))) + (.5 * Z)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    X_ = np.array([
        [0],
        [0],
        [0],
        [0]
    ], dtype=np.float32)
    Y_ = np.array([
        [1]
    ], dtype=np.float32)
    
    optimizer = tf.train.GradientDescentOptimizer(1).minimize(J)                                                   
    sess.run([optimizer, J, var_grad], feed_dict={X: X_, Y: Y_})

@gsutanto
gsutanto commented Sep 19, 2018

Hi, @Dust0x if you work on a batch of data, it would be something like:

from tensorflow.python.framework import function
import numpy as np
import tensorflow as tf

@function.Defun(tf.float32, tf.float32)
def norm_grad(x, dy):
    return dy*(x/(tf.norm(x, ord=2)+1.0e-19))

@function.Defun(tf.float32, grad_func=norm_grad)
def norm(x):
    return tf.norm(x, ord=2)

def norm_axis1(X):
    return tf.map_fn(lambda x: norm(x), X)

X = tf.placeholder(tf.float32, shape=(4,3))
Z = norm_axis1(X)
var_grad = tf.gradients(Z, [X])
norm_var_grad = norm_axis1(var_grad[0])

with tf.Session() as sess:
    X_ = np.array([
        [0,0,0],
        [1.,1.,1.],
        [0.,0.,0.],
        [2.,2.,2.]
    ], dtype=np.float32)
    
    [Z_val, Z_grad, norm_Z_grad] = sess.run([Z, var_grad, norm_var_grad], feed_dict={X: X_})
    
    print("Z_val =", Z_val)
    print("Z_grad =", Z_grad)
    print("norm_Z_grad =", norm_Z_grad)

(Here X_ is N×M = 4×3: each row is a feature vector of size 3 whose norm you would like to compute, and there are N = 4 samples. There might be a better way to do this; if anyone knows one, please post your code here. In terms of TensorFlow, I believe it will automatically use the newly defined norm gradient in any cost function that uses it; please let me know if my understanding is incorrect.)
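A TF 2.x sketch that checks this (tf.custom_gradient and tf.GradientTape stand in for function.Defun and tf.gradients here; guarded_norm is a hypothetical name): the norm is used as a penalty inside a larger loss, and the tape picks up the custom gradient automatically, with no separate gradient call needed for the norm. At w = 0 the penalty contributes 0 to the gradient instead of nan.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


@tf.custom_gradient
def guarded_norm(x):
    """Hypothetical helper: L2 norm of the whole tensor with an epsilon-guarded gradient."""
    y = tf.norm(x)

    def grad(dy):
        return dy * x / (y + 1e-12)

    return y, grad


w = tf.Variable(tf.zeros([4]))  # starts exactly where tf.norm's gradient is nan
target = tf.constant(1.0)

with tf.GradientTape() as tape:
    pred = tf.reduce_sum(w)
    # Pass a tensor read of w (not the Variable object) into the custom-gradient
    # function; the tape still records the read, so gradients reach w.
    penalty = guarded_norm(tf.convert_to_tensor(w))
    loss = (target - pred) ** 2 + 0.5 * penalty

g = tape.gradient(loss, w)
print(g.numpy())  # squared-error term gives -2 per entry; the penalty adds 0, not nan
```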

@aayux
aayux commented Sep 20, 2018

Thanks again, @gsutanto. Luckily for me, the input to the norm only needs to be fed once in my case and so I don't (yet) have to worry about batch computations. The way I am currently solving this is by simply calling the forced gradient of norm in the same session as the optimizer and if what you say is correct — nothing seems to have broken so far, so it must be — this should work just fine.

@thebeancounter

Hi, I am hitting this too when trying to use the pairwise distance norm of a batch of vectors, here is the SO question, am I getting this issue because of this bug?

h4ste pushed a commit to h4ste/oscar that referenced this issue Mar 2, 2019
@bkmi
bkmi commented Apr 15, 2019

Hi, I am hitting this too when trying to use the pairwise distance norm of a batch of vectors, here is the SO question, am I getting this issue because of this bug?

I am getting the same bug on a distance norm of a batch of vectors. We have very similar situations.

--

EDIT: Although this is certainly not the same thing, I found that the sum of absolute values (the L1 norm), i.e. |x_1| + |x_2| + |x_3| + ..., does not have the gradient instabilities.

@macarena-merida-floriano

Hi, @gsutanto, @gokul-uf , I've tried to implement your code, like this:

@tf.custom_gradient
def custom_norm(x):
    y = tf.norm(x, axis=-1, keepdims=False)
    def grad(dy):
        return dy * (x / (tf.expand_dims(y, -1) + 1e-19))
    return y, grad

def custom_norm1(X):
    return tf.map_fn(lambda x: custom_norm(x), X)

I added tf.expand_dims(y, -1) in the gradient calculation and axis=-1 in the norm computation because of my current problem: I'm working with batches of data, so the tensor I have to compute the norm on has shape x = [2, 256000, 2], and I have to compute the norm over the last dimension, so the result should have shape x_norm = [2, 256000] (that's why I added the axis=-1 argument).
The problem comes when trying to compute the gradient, because of the shapes. I had to expand the dimensions of x_norm to be able to divide x / (x_norm + 1e-19), otherwise TensorFlow gives an error. That calculation yields a tensor of shape [2, 256000, 1], but now I get another error:

ValueError: Dimensions must be equal, but are 256000 and 2 for 'gradients_30/map_7/while/IdentityN_grad/mul' (op: 'Mul') with input shapes: [256000], [256000,2].

I can tell it is because now I have to multiply that tensor by dy, which has shape [256000]. I can't figure out how to solve this so that I can compute the norm on a batched tensor without mixing batches. Thank you in advance.

@Takonan
Takonan commented Apr 18, 2020

@macarena1807, I believe you can apply tf.expand_dims(dy, -1) in order to make the dy shape match with the normalized x vector, so that the output of grad(dy) has the same shape as x:

@tf.custom_gradient
def custom_norm(x):
    y = tf.norm(x, axis= - 1 , keepdims=False)
    def grad(dy):
        return tf.expand_dims(dy, -1)*x/(tf.expand_dims(y, -1) + 1e-19)
    return y, grad

def custom_norm1(X):
    return tf.map_fn(lambda x: custom_norm(x), X)
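The tf.map_fn wrapper is probably not needed for a batch: tf.norm(x, axis=-1) already reduces each vector independently, so the same gradient can be written for the whole tensor at once. A sketch assuming TF 2.x eager mode; safe_norm_last_axis is a hypothetical name, and it clamps the denominator with tf.maximum instead of adding an epsilon. The same code should handle shapes like [2, 256000, 2] without mixing batches.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


@tf.custom_gradient
def safe_norm_last_axis(x):
    """Hypothetical helper: L2 norm over the last axis, one value per vector."""
    y = tf.norm(x, axis=-1)  # shape: x.shape[:-1]

    def grad(dy):
        # dy has the shape of y; expand dy and y so both broadcast against x.
        denom = tf.maximum(tf.expand_dims(y, -1), 1e-12)
        return tf.expand_dims(dy, -1) * x / denom

    return y, grad


x = tf.constant([[0.0, 0.0], [3.0, 4.0]])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = safe_norm_last_axis(x)
g = tape.gradient(y, x)
print(y.numpy(), g.numpy())  # norms [0, 5]; gradient rows [0, 0] and [0.6, 0.8]

# Batched input, e.g. [batch, points, 2]: no tf.map_fn needed.
xb = tf.random.normal([2, 5, 2])
with tf.GradientTape() as tape:
    tape.watch(xb)
    yb = safe_norm_last_axis(xb)
gb = tape.gradient(yb, xb)
print(yb.shape, gb.shape)  # (2, 5) (2, 5, 2)
```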

@rmlarsen rmlarsen self-assigned this Jul 7, 2020
CloudyOverhead pushed a commit to gfabieno/GeoFlow that referenced this issue Dec 3, 2020
Prevent losses from diverging as per 
tensorflow/tensorflow#12071.
@carlosg-m
carlosg-m commented Dec 16, 2020

Thank you for the solution @yaroslavvb, @gsutanto and @gokul-uf.

I needed the L2 Norm for a personal project and I was having a lot of issues. During the first epochs the network was stable and even able to converge, and then, suddenly the loss exploded to infinity, without any kind of warning or symptoms.

I tested every technique I knew to try to solve the problem without really knowing what was going on, including gradient clipping, regularization, different architectures, activation functions and initialization techniques, residual connections, and different optimizers and learning rates, but nothing really worked. Then, keeping the same structure, I changed the norm to L1 and everything worked fine; however, L2 was a necessary requirement.

This solution completely solved the exploding gradient issue while using the L2 Norm, stabilized the training and provided the best results so far.

Imports:

from tensorflow.keras import layers, models, regularizers, activations
import tensorflow.keras.backend as K
import tensorflow as tf

Custom Keras L2 Norm Layer with custom gradient to prevent numerical instability:

  • I replaced the small epsilon added to the gradient's denominator (y+1.0e-19) with K.maximum(y, K.epsilon()), replicating the design used in Keras loss functions.
@tf.custom_gradient
def l2_norm(x):
   y = tf.norm(x, ord='euclidean', keepdims=True, axis=1)
   def grad(dy):
       return dy*(x/K.maximum(y,K.epsilon()))
   return y, grad
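The two denominators differ near zero. A NumPy sketch of both gradient formulas (grad_additive and grad_clamped are hypothetical names; EPS = 1e-7 matches the default value of K.epsilon()): the clamp leaves the gradient exact whenever the norm is at least EPS, while the additive form shrinks it slightly everywhere. Both return 0 for an all-zero vector, and both deviate from x / ||x|| when the norm is below EPS (the clamp gives [0.3, 0.4] instead of [0.6, 0.8] for the middle row).

```python
import numpy as np

EPS = 1e-7  # default value of K.epsilon() in Keras


def grad_additive(x, eps=EPS):
    """Gradient of ||x|| over the last axis with an additive epsilon: x / (||x|| + eps)."""
    n = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / (n + eps)


def grad_clamped(x, eps=EPS):
    """Same gradient with a clamp instead: x / max(||x||, eps)."""
    n = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(n, eps)


x = np.array([[0.0, 0.0],      # exactly zero
              [3e-8, 4e-8],    # norm 5e-8, below EPS
              [3.0, 4.0]])     # norm 5, well above EPS
print(grad_additive(x))
print(grad_clamped(x))
```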

Example of how to apply layer in a tf.Keras model:

def base_model():
    
  input_single = layers.Input(shape=(n,))
  x = layers.Dense(...)(input_single)
  x = layers.Dense(...)(x)    
    
  return tf.keras.Model(input_single, x)

def full_model():
  
  input_pair = layers.Input(shape=(2,n))
  x = layers.TimeDistributed(base_model())(input_pair)
  x = layers.Subtract()([x[:,0],x[:,1]])
  x = layers.Lambda(l2_norm)(x)

  return tf.keras.Model(input_pair, x)

@github-actions

This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Mar 28, 2023
@jiqiujia

I wonder if this issue still exists with tf.norm? It seems that I should use tf.nn.l2_normalize instead, since I see hanxiao has made a fix for it?
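Note that tf.nn.l2_normalize returns x / ||x||, not the norm itself, so it is only a replacement where the normalized vector is what is needed. A TF 2.x sketch of its behaviour at zero (assumes eager execution): it divides by sqrt(max(sum(x**2), epsilon)) with a default epsilon of 1e-12, so the gradient at an all-zero row is finite but large, about 1/sqrt(epsilon) = 1e6 per component in this example.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

x = tf.constant([[0.0, 0.0], [3.0, 4.0]])
with tf.GradientTape() as tape:
    tape.watch(x)
    u = tf.nn.l2_normalize(x, axis=-1)  # row-wise x / max(||x||, sqrt(1e-12))
    s = tf.reduce_sum(u)                 # a scalar to differentiate
g = tape.gradient(s, x)
print(u.numpy())
print(g.numpy())  # finite everywhere; the zero row gets roughly 1e6 per entry
```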

@google-ml-butler google-ml-butler bot removed the stale This label marks the issue/pr stale - to be closed automatically if no activity label Mar 11, 2024
@sushreebarsa
Contributor

Hi,

A common and effective approach is to add a small epsilon value (ε) to the squared sum before taking the square root in the norm calculation. This prevents division by zero and improves stability for very small values. Here's how you can modify your code:

def safe_norm(x, epsilon=1e-12, axis=None):
    return tf.sqrt(tf.reduce_sum(x ** 2, axis=axis) + epsilon)

Z = safe_norm(X, axis=1)
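A quick TF 2.x check of this suggestion (assumes eager execution; safe_norm is re-declared so the snippet runs on its own, and it accepts only x, epsilon and axis). The gradient at a zero row becomes 0, but the forward value there is sqrt(epsilon) = 1e-6 rather than 0, and every other norm is shifted slightly upward, which may or may not matter for a given loss.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def safe_norm(x, epsilon=1e-12, axis=None):
    # epsilon is added inside the square root, so sqrt never sees exactly 0
    return tf.sqrt(tf.reduce_sum(x ** 2, axis=axis) + epsilon)


x = tf.constant([[0.0, 0.0], [3.0, 4.0]])
with tf.GradientTape() as tape:
    tape.watch(x)
    z = safe_norm(x, axis=1)
g = tape.gradient(z, x)
print(z.numpy())  # about [1e-6, 5]: the zero row no longer has norm exactly 0
print(g.numpy())  # rows [0, 0] and [0.6, 0.8]
```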

Since this issue has been open for a long time, the code/debug information for this issue may not be relevant with the current state of the code base.

The Tensorflow team is constantly improving the framework by fixing bugs and adding new features. We suggest you try the latest TensorFlow version with the latest compatible hardware configuration which could potentially resolve the issue. If you are still facing the issue, please create a new GitHub issue with your latest findings, with all the debugging information which could help us investigate.

Please follow the release notes to stay up to date with the latest developments which are happening in the Tensorflow space.

Thank you!

@sushreebarsa sushreebarsa added the stat:awaiting response Status - Awaiting response from author label May 23, 2024
@sushreebarsa sushreebarsa self-assigned this May 23, 2024
@oduerr
Author
oduerr commented May 23, 2024

Hi,

If the standard implementation is still not working, it will still cause problems for users of the library.

@oduerr oduerr closed this as completed May 23, 2024
@oduerr
Author
oduerr commented May 23, 2024

Hi,

If the standard implementation is still not working, it will still cause problems for users of the library. (I accidentally closed it in my comment before)


@oduerr oduerr reopened this May 23, 2024
@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label May 23, 2024
@sushreebarsa sushreebarsa removed their assignment May 24, 2024