[go: nahoru, domu]

Migrate Hilt Workers to use Dagger Assisted Inject

Update the API from @WorkerInject to @HiltWorker, similar to the migration from @ViewModelInject to @HiltViewModel. The main difference is that workers use Dagger's assisted injection, which means their constructors have to be annotated with @AssistedInject and both the Context and WorkerParameters parameters with @Assisted; the Context parameter must also be declared before the WorkerParameters parameter. These changes give the two APIs some parity, as they behave similarly.

Test: Hilt Worker Integration App
Relnote: Replace @WorkerInject with @HiltWorker. @HiltWorker is now a type annotation and requires the usage of @AssistedInject in the constructor.
Change-Id: Ic2f15c63880a02ed082e9205fcad7acdb2b38751
diff --git a/hilt/hilt-common/api/current.txt b/hilt/hilt-common/api/current.txt
index c62b3cb..2d016ea 100644
--- a/hilt/hilt-common/api/current.txt
+++ b/hilt/hilt-common/api/current.txt
@@ -15,7 +15,7 @@
 
 package androidx.hilt.work {
 
-  @dagger.hilt.GeneratesRootInput @java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.CLASS) @java.lang.annotation.Target(java.lang.annotation.ElementType.CONSTRUCTOR) public @interface WorkerInject {
+  @dagger.hilt.GeneratesRootInput @java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.CLASS) @java.lang.annotation.Target(java.lang.annotation.ElementType.TYPE) public @interface HiltWorker {
   }
 
 }
diff --git a/hilt/hilt-common/api/public_plus_experimental_current.txt b/hilt/hilt-common/api/public_plus_experimental_current.txt
index c62b3cb..2d016ea 100644
--- a/hilt/hilt-common/api/public_plus_experimental_current.txt
+++ b/hilt/hilt-common/api/public_plus_experimental_current.txt
@@ -15,7 +15,7 @@
 
 package androidx.hilt.work {
 
-  @dagger.hilt.GeneratesRootInput @java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.CLASS) @java.lang.annotation.Target(java.lang.annotation.ElementType.CONSTRUCTOR) public @interface WorkerInject {
+  @dagger.hilt.GeneratesRootInput @java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.CLASS) @java.lang.annotation.Target(java.lang.annotation.ElementType.TYPE) public @interface HiltWorker {
   }
 
 }
diff --git a/hilt/hilt-common/api/restricted_current.txt b/hilt/hilt-common/api/restricted_current.txt
index c62b3cb..2d016ea 100644
--- a/hilt/hilt-common/api/restricted_current.txt
+++ b/hilt/hilt-common/api/restricted_current.txt
@@ -15,7 +15,7 @@
 
 package androidx.hilt.work {
 
-  @dagger.hilt.GeneratesRootInput @java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.CLASS) @java.lang.annotation.Target(java.lang.annotation.ElementType.CONSTRUCTOR) public @interface WorkerInject {
+  @dagger.hilt.GeneratesRootInput @java.lang.annotation.Retention(java.lang.annotation.RetentionPolicy.CLASS) @java.lang.annotation.Target(java.lang.annotation.ElementType.TYPE) public @interface HiltWorker {
   }
 
 }
diff --git a/hilt/hilt-common/src/main/java/androidx/hilt/Assisted.java b/hilt/hilt-common/src/main/java/androidx/hilt/Assisted.java
index 72e1e3a..a84b6d3 100644
--- a/hilt/hilt-common/src/main/java/androidx/hilt/Assisted.java
+++ b/hilt/hilt-common/src/main/java/androidx/hilt/Assisted.java
@@ -22,9 +22,7 @@
 import java.lang.annotation.Target;
 
 /**
- * Marks a parameter in a {@link androidx.hilt.lifecycle.ViewModelInject}-annotated constructor
- * or a {@link androidx.hilt.work.WorkerInject}-annotated constructor to be assisted
- * injected at runtime via a factory.
+ * Marks a parameter in a {@link androidx.hilt.lifecycle.ViewModelInject}-annotated constructor.
  *
  * @deprecated Use {@link dagger.assisted.Assisted}
  */
diff --git a/hilt/hilt-common/src/main/java/androidx/hilt/work/WorkerInject.java b/hilt/hilt-common/src/main/java/androidx/hilt/work/HiltWorker.java
similarity index 78%
rename from hilt/hilt-common/src/main/java/androidx/hilt/work/WorkerInject.java
rename to hilt/hilt-common/src/main/java/androidx/hilt/work/HiltWorker.java
index 7beca86..4508d16 100644
--- a/hilt/hilt-common/src/main/java/androidx/hilt/work/WorkerInject.java
+++ b/hilt/hilt-common/src/main/java/androidx/hilt/work/HiltWorker.java
@@ -26,17 +26,19 @@
 /**
- * Identifies a {@link androidx.work.ListenableWorker}'s constructor for injection.
+ * Identifies a {@link androidx.work.ListenableWorker} for assisted injection by Hilt.
  * <p>
- * Similar to {@link javax.inject.Inject}, a {@code Worker} containing a constructor annotated
- * with {@code WorkerInject} will have its dependencies defined in the constructor parameters
- * injected by Dagger's Hilt. The {@code Worker} will be available for creation by the
+ * The {@code Worker} will be available for creation by the
  * {@link androidx.hilt.work.HiltWorkerFactory} that should be set in {@code WorkManager}'s
  * configuration via
  * {@link androidx.work.Configuration.Builder#setWorkerFactory(androidx.work.WorkerFactory)}.
+ * A {@code HiltWorker}-annotated class must contain a constructor annotated with
+ * {@link dagger.assisted.AssistedInject}; the dependencies defined as that constructor's
+ * parameters are injected by Hilt.
  * <p>
  * Example:
  * <pre>
+ * &#64;HiltWorker
  * public class UploadWorker extends Worker {
- *     &#64;WorkerInject
+ *     &#64;AssistedInject
  *     public UploadWorker(&#64;Assisted Context context, &#64;Assisted WorkerParameters params,
  *             HttpClient httpClient) {
  *         // ...
@@ -57,17 +59,18 @@
  * }
  * </pre>
  * <p>
- * Only one constructor in the {@code Worker} must be annotated with {@code WorkerInject}. The
- * constructor must define parameters for a {@link androidx.hilt.Assisted}-annotated {@code Context}
- * and a {@link androidx.hilt.Assisted}-annotated {@code WorkerParameters} along with any other
+ * Exactly one constructor in the {@code Worker} must be annotated with
+ * {@link dagger.assisted.AssistedInject}. The constructor must define parameters for a
+ * {@link dagger.assisted.Assisted}-annotated {@code Context} and a
+ * {@link dagger.assisted.Assisted}-annotated {@code WorkerParameters} along with any other
  * dependencies. Both the {@code Context} and {@code WorkerParameters} must not be a type param
  * of {@link javax.inject.Provider} nor {@link dagger.Lazy} and must not be qualified.
  * <p>
  * Only dependencies available in the {@link dagger.hilt.components.SingletonComponent}
  * can be injected into the {@code Worker}.
  */
-@Target(ElementType.CONSTRUCTOR)
+@Target(ElementType.TYPE)
 @Retention(RetentionPolicy.CLASS)
 @GeneratesRootInput
-public @interface WorkerInject {
+public @interface HiltWorker {
 }
diff --git a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/AndroidXHiltProcessor.kt b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/AndroidXHiltProcessor.kt
index fb305895..e5c8c10 100644
--- a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/AndroidXHiltProcessor.kt
+++ b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/AndroidXHiltProcessor.kt
@@ -17,7 +17,7 @@
 package androidx.hilt
 
 import androidx.hilt.lifecycle.ViewModelInjectStep
-import androidx.hilt.work.WorkerInjectStep
+import androidx.hilt.work.WorkerStep
 import com.google.auto.service.AutoService
 import net.ltgt.gradle.incap.IncrementalAnnotationProcessor
 import net.ltgt.gradle.incap.IncrementalAnnotationProcessorType.ISOLATING
@@ -37,7 +37,7 @@
 
     override fun getSupportedAnnotationTypes() = setOf(
         ClassNames.VIEW_MODEL_INJECT.canonicalName(),
-        ClassNames.WORKER_INJECT.canonicalName()
+        ClassNames.HILT_WORKER.canonicalName()
     )
 
     override fun getSupportedSourceVersion() = SourceVersion.latest()
@@ -56,7 +56,7 @@
 
     private fun getSteps() = listOf(
         ViewModelInjectStep(processingEnv),
-        WorkerInjectStep(processingEnv)
+        WorkerStep(processingEnv)
     )
 
     interface Step {
diff --git a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/ClassNames.kt b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/ClassNames.kt
index c627b01..f6fbb5b 100644
--- a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/ClassNames.kt
+++ b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/ClassNames.kt
@@ -21,9 +21,13 @@
 internal object ClassNames {
     val ACTIVITY_RETAINED_COMPONENT =
         ClassName.get("dagger.hilt.android.components", "ActivityRetainedComponent")
-    val ASSISTED = ClassName.get("androidx.hilt", "Assisted")
+    val ANDROIDX_ASSISTED = ClassName.get("androidx.hilt", "Assisted")
+    val ASSISTED = ClassName.get("dagger.assisted", "Assisted")
+    val ASSISTED_FACTORY = ClassName.get("dagger.assisted", "AssistedFactory")
+    val ASSISTED_INJECT = ClassName.get("dagger.assisted", "AssistedInject")
     val BINDS = ClassName.get("dagger", "Binds")
     val CONTEXT = ClassName.get("android.content", "Context")
+    val HILT_WORKER = ClassName.get("androidx.hilt.work", "HiltWorker")
     val NON_NULL = ClassName.get("androidx.annotation", "NonNull")
     val INJECT = ClassName.get("javax.inject", "Inject")
     val INSTALL_IN = ClassName.get("dagger.hilt", "InstallIn")
@@ -42,6 +46,5 @@
     val STRING_KEY = ClassName.get("dagger.multibindings", "StringKey")
     val WORKER = ClassName.get("androidx.work", "Worker")
     val WORKER_ASSISTED_FACTORY = ClassName.get("androidx.hilt.work", "WorkerAssistedFactory")
-    val WORKER_INJECT = ClassName.get("androidx.hilt.work", "WorkerInject")
     val WORKER_PARAMETERS = ClassName.get("androidx.work", "WorkerParameters")
 }
diff --git a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/assisted/DependencyRequest.kt b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/assisted/DependencyRequest.kt
index eb4d5e6..a2edb27 100644
--- a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/assisted/DependencyRequest.kt
+++ b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/assisted/DependencyRequest.kt
@@ -59,7 +59,10 @@
     return DependencyRequest(
         name = simpleName.toString(),
         type = type,
-        isAssisted = hasAnnotation(ClassNames.ASSISTED.canonicalName()) && qualifier == null,
+        isAssisted = (
+            hasAnnotation(ClassNames.ANDROIDX_ASSISTED.canonicalName()) ||
+                hasAnnotation(ClassNames.ASSISTED.canonicalName())
+            ) && qualifier == null,
         qualifier = qualifier
     )
 }
\ No newline at end of file
diff --git a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/lifecycle/ViewModelInjectStep.kt b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/lifecycle/ViewModelInjectStep.kt
index c86a7bf..fadbe8e 100644
--- a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/lifecycle/ViewModelInjectStep.kt
+++ b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/lifecycle/ViewModelInjectStep.kt
@@ -121,7 +121,7 @@
                 valid = false
             }
             firstOrNull()?.let {
-                if (!it.hasAnnotation(ClassNames.ASSISTED.canonicalName())) {
+                if (!it.hasAnnotation(ClassNames.ANDROIDX_ASSISTED.canonicalName())) {
                     error("Missing @Assisted annotation in param '${it.simpleName}'.", it)
                     valid = false
                 }
diff --git a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerInjectElements.kt b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerElements.kt
similarity index 97%
rename from hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerInjectElements.kt
rename to hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerElements.kt
index 03f4c05..0b6c84d 100644
--- a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerInjectElements.kt
+++ b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerElements.kt
@@ -27,7 +27,7 @@
 /**
  * Data class that represents a Hilt injected Worker
  */
-internal data class WorkerInjectElements(
+internal data class WorkerElements(
     val typeElement: TypeElement,
     val constructorElement: ExecutableElement
 ) {
diff --git a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerGenerator.kt b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerGenerator.kt
index 3552dfc..4052322 100644
--- a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerGenerator.kt
+++ b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerGenerator.kt
@@ -17,7 +17,6 @@
 package androidx.hilt.work
 
 import androidx.hilt.ClassNames
-import androidx.hilt.assisted.AssistedFactoryGenerator
 import androidx.hilt.ext.S
 import androidx.hilt.ext.T
 import androidx.hilt.ext.addGeneratedAnnotation
@@ -46,40 +45,27 @@
  * ```
  * and
  * ```
- * public final class $_AssistedFactory extends WorkerAssistedFactory<$> {
+ * @AssistedFactory
+ * public interface $_AssistedFactory extends WorkerAssistedFactory<$> {
  *
- *   private final Provider<Dep1> dep1;
- *   private final Provider<Dep2> dep2;
- *   ...
- *
- *   @Inject
- *   $_AssistedFactory(Provider<Dep1> dep1, Provider<Dep2> dep2, ...) {
- *     this.dep1 = dep1;
- *     this.dep2 = dep2;
- *     ...
- *   }
- *
- *   @Override
- *   @NonNull
- *   public $ create(@NonNull Context context, @NonNull WorkerParameter params) {
- *     return new $(context, params, dep1.get(), dep2.get());
- *   }
  * }
  * ```
  */
 internal class WorkerGenerator(
     private val processingEnv: ProcessingEnvironment,
-    private val injectedWorker: WorkerInjectElements
+    private val injectedWorker: WorkerElements
 ) {
     fun generate() {
-        AssistedFactoryGenerator(
-            processingEnv = processingEnv,
-            productClassName = injectedWorker.className,
-            factoryClassName = injectedWorker.factoryClassName,
-            factorySuperTypeName = injectedWorker.factorySuperTypeName,
-            originatingElement = injectedWorker.typeElement,
-            dependencyRequests = injectedWorker.dependencyRequests
-        ).generate()
+        val assistedFactoryTypeSpec = TypeSpec.interfaceBuilder(injectedWorker.factoryClassName)
+            .addOriginatingElement(injectedWorker.typeElement)
+            .addGeneratedAnnotation(processingEnv.elementUtils, processingEnv.sourceVersion)
+            .addAnnotation(ClassNames.ASSISTED_FACTORY)
+            .addModifiers(Modifier.PUBLIC)
+            .addSuperinterface(injectedWorker.factorySuperTypeName)
+            .build()
+        JavaFile.builder(injectedWorker.factoryClassName.packageName(), assistedFactoryTypeSpec)
+            .build()
+            .writeTo(processingEnv.filer)
 
         val hiltModuleTypeSpec = TypeSpec.interfaceBuilder(injectedWorker.moduleClassName)
             .addOriginatingElement(injectedWorker.typeElement)
diff --git a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerInjectStep.kt b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerInjectStep.kt
deleted file mode 100644
index 346f308..0000000
--- a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerInjectStep.kt
+++ /dev/null
@@ -1,159 +0,0 @@
-/*
- * Copyright 2020 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package androidx.hilt.work
-
-import androidx.hilt.AndroidXHiltProcessor
-import androidx.hilt.ClassNames
-import androidx.hilt.ext.hasAnnotation
-import com.google.auto.common.MoreElements
-import com.squareup.javapoet.TypeName
-import javax.annotation.processing.ProcessingEnvironment
-import javax.lang.model.element.Element
-import javax.lang.model.element.ExecutableElement
-import javax.lang.model.element.Modifier
-import javax.lang.model.element.NestingKind
-import javax.lang.model.element.TypeElement
-import javax.lang.model.util.ElementFilter
-import javax.tools.Diagnostic
-
-/**
- * Processing step that generates code enabling assisted injection of Workers using Hilt.
- */
-class WorkerInjectStep(
-    private val processingEnv: ProcessingEnvironment
-) : AndroidXHiltProcessor.Step {
-
-    private val elements = processingEnv.elementUtils
-    private val types = processingEnv.typeUtils
-    private val messager = processingEnv.messager
-
-    override fun annotation() = ClassNames.WORKER_INJECT.canonicalName()
-
-    override fun process(annotatedElements: Set<Element>) {
-        val parsedElements = mutableSetOf<TypeElement>()
-        annotatedElements.forEach { element ->
-            val constructorElement =
-                MoreElements.asExecutable(element)
-            val typeElement =
-                MoreElements.asType(constructorElement.enclosingElement)
-            if (parsedElements.add(typeElement)) {
-                parse(typeElement, constructorElement)?.let { worker ->
-                    WorkerGenerator(
-                        processingEnv,
-                        worker
-                    ).generate()
-                }
-            }
-        }
-    }
-
-    private fun parse(
-        typeElement: TypeElement,
-        constructorElement: ExecutableElement
-    ): WorkerInjectElements? {
-        var valid = true
-
-        if (elements.getTypeElement(ClassNames.WORKER_ASSISTED_FACTORY.toString()) == null) {
-            error(
-                "To use @WorkerInject you must add the 'work' artifact. " +
-                    "androidx.hilt:hilt-work:<version>"
-            )
-            valid = false
-        }
-
-        if (!types.isSubtype(
-                typeElement.asType(),
-                elements.getTypeElement(ClassNames.LISTENABLE_WORKER.toString()).asType()
-            )
-        ) {
-            error(
-                "@WorkerInject is only supported on types that subclass " +
-                    "${ClassNames.LISTENABLE_WORKER}."
-            )
-            valid = false
-        }
-
-        ElementFilter.constructorsIn(typeElement.enclosedElements).filter {
-            it.hasAnnotation(ClassNames.WORKER_INJECT.canonicalName())
-        }.let { constructors ->
-            if (constructors.size > 1) {
-                error("Multiple @WorkerInject annotated constructors found.", typeElement)
-                valid = false
-            }
-            constructors.filter { it.modifiers.contains(Modifier.PRIVATE) }.forEach {
-                error("@WorkerInject annotated constructors must not be private.", it)
-                valid = false
-            }
-        }
-
-        if (typeElement.nestingKind == NestingKind.MEMBER &&
-            !typeElement.modifiers.contains(Modifier.STATIC)
-        ) {
-            error(
-                "@WorkerInject may only be used on inner classes if they are static.",
-                typeElement
-            )
-            valid = false
-        }
-
-        constructorElement.parameters.filter {
-            TypeName.get(it.asType()) == ClassNames.CONTEXT
-        }.apply {
-            if (size != 1) {
-                error(
-                    "Expected exactly one constructor argument of type " +
-                        "${ClassNames.CONTEXT}, found $size",
-                    constructorElement
-                )
-                valid = false
-            }
-            firstOrNull()?.let {
-                if (!it.hasAnnotation(ClassNames.ASSISTED.canonicalName())) {
-                    error("Missing @Assisted annotation in param '${it.simpleName}'.", it)
-                    valid = false
-                }
-            }
-        }
-
-        constructorElement.parameters.filter {
-            TypeName.get(it.asType()) == ClassNames.WORKER_PARAMETERS
-        }.apply {
-            if (size != 1) {
-                error(
-                    "Expected exactly one constructor argument of type " +
-                        "${ClassNames.WORKER_PARAMETERS}, found $size",
-                    constructorElement
-                )
-                valid = false
-            }
-            firstOrNull()?.let {
-                if (!it.hasAnnotation(ClassNames.ASSISTED.canonicalName())) {
-                    error("Missing @Assisted annotation in param '${it.simpleName}'.", it)
-                    valid = false
-                }
-            }
-        }
-
-        if (!valid) return null
-
-        return WorkerInjectElements(typeElement, constructorElement)
-    }
-
-    private fun error(message: String, element: Element? = null) {
-        messager.printMessage(Diagnostic.Kind.ERROR, message, element)
-    }
-}
\ No newline at end of file
diff --git a/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerStep.kt b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerStep.kt
new file mode 100644
index 0000000..0fbdfa3
--- /dev/null
+++ b/hilt/hilt-compiler/src/main/kotlin/androidx/hilt/work/WorkerStep.kt
@@ -0,0 +1,181 @@
+/*
+ * Copyright 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.hilt.work
+
+import androidx.hilt.AndroidXHiltProcessor
+import androidx.hilt.ClassNames
+import androidx.hilt.ext.hasAnnotation
+import com.google.auto.common.MoreElements
+import com.squareup.javapoet.TypeName
+import javax.annotation.processing.ProcessingEnvironment
+import javax.lang.model.element.Element
+import javax.lang.model.element.ExecutableElement
+import javax.lang.model.element.Modifier
+import javax.lang.model.element.NestingKind
+import javax.lang.model.element.TypeElement
+import javax.lang.model.util.ElementFilter
+import javax.tools.Diagnostic
+
+/**
+ * Processing step that generates code enabling assisted injection of Workers using Hilt.
+ *
+ * Each `@HiltWorker`-annotated type is validated and, if valid, handed to [WorkerGenerator].
+ */
+class WorkerStep(
+    private val processingEnv: ProcessingEnvironment
+) : AndroidXHiltProcessor.Step {
+
+    private val elements = processingEnv.elementUtils
+    private val types = processingEnv.typeUtils
+    private val messager = processingEnv.messager
+
+    override fun annotation() = ClassNames.HILT_WORKER.canonicalName()
+
+    override fun process(annotatedElements: Set<Element>) {
+        val parsedElements = mutableSetOf<TypeElement>()
+        annotatedElements.forEach { element ->
+            val typeElement = MoreElements.asType(element)
+            if (parsedElements.add(typeElement)) {
+                parse(typeElement)?.let { worker ->
+                    WorkerGenerator(
+                        processingEnv,
+                        worker
+                    ).generate()
+                }
+            }
+        }
+    }
+
+    /**
+     * Validates a `@HiltWorker`-annotated [typeElement], reporting any errors via [messager].
+     *
+     * @return the [WorkerElements] to generate code for, or `null` if validation failed.
+     */
+    private fun parse(typeElement: TypeElement): WorkerElements? {
+        var valid = true
+
+        if (elements.getTypeElement(ClassNames.WORKER_ASSISTED_FACTORY.toString()) == null) {
+            error(
+                "To use @HiltWorker you must add the 'work' artifact. " +
+                    "androidx.hilt:hilt-work:<version>"
+            )
+            valid = false
+        }
+
+        // ListenableWorker may be absent from the classpath; report the usual error instead of
+        // dereferencing a null TypeElement.
+        val listenableWorkerType =
+            elements.getTypeElement(ClassNames.LISTENABLE_WORKER.toString())?.asType()
+        if (listenableWorkerType == null ||
+            !types.isSubtype(typeElement.asType(), listenableWorkerType)
+        ) {
+            error(
+                "@HiltWorker is only supported on types that subclass " +
+                    "${ClassNames.LISTENABLE_WORKER}.",
+                typeElement
+            )
+            valid = false
+        }
+
+        val constructors = ElementFilter.constructorsIn(typeElement.enclosedElements).filter {
+            if (it.hasAnnotation(ClassNames.INJECT.canonicalName())) {
+                error(
+                    "Worker constructor should be annotated with @AssistedInject instead of " +
+                        "@Inject.",
+                    it
+                )
+                valid = false
+            }
+            it.hasAnnotation(ClassNames.ASSISTED_INJECT.canonicalName())
+        }
+        if (constructors.size != 1) {
+            error(
+                "@HiltWorker annotated class should contain exactly one @AssistedInject " +
+                    "annotated constructor.",
+                typeElement
+            )
+            valid = false
+        }
+        constructors.filter { it.modifiers.contains(Modifier.PRIVATE) }.forEach {
+            error("@AssistedInject annotated constructors must not be private.", it)
+            valid = false
+        }
+
+        if (typeElement.nestingKind == NestingKind.MEMBER &&
+            !typeElement.modifiers.contains(Modifier.STATIC)
+        ) {
+            error(
+                "@HiltWorker may only be used on inner classes if they are static.",
+                typeElement
+            )
+            valid = false
+        }
+
+        if (!valid) return null
+
+        // Exactly one @AssistedInject constructor is guaranteed by the checks above.
+        val injectConstructor = constructors.single()
+        val contextIndex = assistedParamIndex(injectConstructor, ClassNames.CONTEXT)
+        val workerParametersIndex =
+            assistedParamIndex(injectConstructor, ClassNames.WORKER_PARAMETERS)
+        if (contextIndex == null || workerParametersIndex == null) {
+            // assistedParamIndex() has already reported the error.
+            valid = false
+        } else if (contextIndex > workerParametersIndex) {
+            error(
+                "The 'Context' parameter must be declared before the 'WorkerParameters' in the " +
+                    "@AssistedInject constructor of a @HiltWorker annotated class.",
+                injectConstructor
+            )
+            valid = false
+        }
+
+        if (!valid) return null
+
+        return WorkerElements(typeElement, injectConstructor)
+    }
+
+    /**
+     * Checks that [constructor] declares exactly one parameter of type [typeName] and that it
+     * is annotated with `@dagger.assisted.Assisted`, reporting an error otherwise.
+     *
+     * @return the index of that parameter, or `null` if an error was reported.
+     */
+    private fun assistedParamIndex(constructor: ExecutableElement, typeName: TypeName): Int? {
+        val matches = constructor.parameters.withIndex().filter { (_, param) ->
+            TypeName.get(param.asType()) == typeName
+        }
+        if (matches.size != 1) {
+            error(
+                "Expected exactly one constructor argument of type $typeName, " +
+                    "found ${matches.size}",
+                constructor
+            )
+            return null
+        }
+        val (index, param) = matches.single()
+        if (!param.hasAnnotation(ClassNames.ASSISTED.canonicalName())) {
+            error("Missing @Assisted annotation in param '${param.simpleName}'.", param)
+            return null
+        }
+        return index
+    }
+
+    private fun error(message: String, element: Element? = null) {
+        messager.printMessage(Diagnostic.Kind.ERROR, message, element)
+    }
+}
\ No newline at end of file
diff --git a/hilt/hilt-compiler/src/test/kotlin/androidx/hilt/work/WorkerGeneratorTest.kt b/hilt/hilt-compiler/src/test/kotlin/androidx/hilt/work/WorkerGeneratorTest.kt
index 9d87edb..fc892a0 100644
--- a/hilt/hilt-compiler/src/test/kotlin/androidx/hilt/work/WorkerGeneratorTest.kt
+++ b/hilt/hilt-compiler/src/test/kotlin/androidx/hilt/work/WorkerGeneratorTest.kt
@@ -41,14 +41,16 @@
         package androidx.hilt.work.test;
 
         import android.content.Context;
-        import androidx.hilt.Assisted;
-        import androidx.hilt.work.WorkerInject;
+        import androidx.hilt.work.HiltWorker;
         import androidx.work.Worker;
         import androidx.work.WorkerParameters;
+        import dagger.assisted.Assisted;
+        import dagger.assisted.AssistedInject;
         import java.lang.String;
 
+        @HiltWorker
         class MyWorker extends Worker {
-            @WorkerInject
+            @AssistedInject
             MyWorker(@Assisted Context context, @Assisted WorkerParameters params, String s,
                     Foo f, long l) {
                 super(context, params);
@@ -59,37 +61,13 @@
         val expected = """
         package androidx.hilt.work.test;
 
-        import android.content.Context;
-        import androidx.annotation.NonNull;
         import androidx.hilt.work.WorkerAssistedFactory;
-        import androidx.work.WorkerParameters;
-        import java.lang.Long;
-        import java.lang.Override;
-        import java.lang.String;
+        import dagger.assisted.AssistedFactory;
         import $GENERATED_TYPE;
-        import javax.inject.Inject;
-        import javax.inject.Provider;
 
         $GENERATED_ANNOTATION
-        public final class MyWorker_AssistedFactory implements
-                WorkerAssistedFactory<MyWorker> {
-
-            private final Provider<String> s;
-            private final Provider<Foo> f;
-            private final Provider<Long> l;
-
-            @Inject
-            MyWorker_AssistedFactory(Provider<String> s, Provider<Foo> f, Provider<Long> l) {
-                this.s = s;
-                this.f = f;
-                this.l = l;
-            }
-
-            @Override
-            @NonNull
-            public MyWorker create(Context context, WorkerParameters parameters) {
-                return new MyWorker(context, parameters, s.get(), f.get(), l.get());
-            }
+        @AssistedFactory
+        public interface MyWorker_AssistedFactory extends WorkerAssistedFactory<MyWorker> {
         }
         """.toJFO("androidx.hilt.work.test.MyWorker_AssistedFactory")
 
@@ -111,13 +89,15 @@
         package androidx.hilt.work.test;
 
         import android.content.Context;
-        import androidx.hilt.Assisted;
-        import androidx.hilt.work.WorkerInject;
+        import androidx.hilt.work.HiltWorker;
         import androidx.work.Worker;
         import androidx.work.WorkerParameters;
+        import dagger.assisted.Assisted;
+        import dagger.assisted.AssistedInject;
 
+        @HiltWorker
         class MyWorker extends Worker {
-            @WorkerInject
+            @AssistedInject
             MyWorker(@Assisted Context context, @Assisted WorkerParameters params) {
                 super(context, params);
             }
diff --git a/hilt/hilt-compiler/src/test/kotlin/androidx/hilt/work/WorkerInjectStepTest.kt b/hilt/hilt-compiler/src/test/kotlin/androidx/hilt/work/WorkerInjectStepTest.kt
deleted file mode 100644
index 0f83fb6..0000000
--- a/hilt/hilt-compiler/src/test/kotlin/androidx/hilt/work/WorkerInjectStepTest.kt
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * Copyright 2020 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package androidx.hilt.work
-
-import androidx.hilt.Sources
-import androidx.hilt.compiler
-import androidx.hilt.toJFO
-import com.google.testing.compile.CompilationSubject.assertThat
-import org.junit.Test
-import org.junit.runner.RunWith
-import org.junit.runners.JUnit4
-
-@RunWith(JUnit4::class)
-class WorkerInjectStepTest {
-
-    @Test
-    fun verifyEnclosingElementExtendsWorker() {
-        val myWorker = """
-        package androidx.hilt.work.test;
-
-        import android.content.Context;
-        import androidx.hilt.Assisted;
-        import androidx.hilt.work.WorkerInject;
-        import androidx.work.WorkerParameters;
-
-        class MyWorker {
-            @WorkerInject
-            MyWorker(@Assisted Context context, @Assisted WorkerParameters params) { }
-        }
-        """.toJFO("androidx.hilt.work.work.MyWorker")
-
-        val compilation = compiler()
-            .compile(myWorker, Sources.LISTENABLE_WORKER, Sources.WORKER, Sources.WORKER_PARAMETERS)
-        assertThat(compilation).apply {
-            failed()
-            hadErrorCount(1)
-            hadErrorContainingMatch(
-                "@WorkerInject is only supported on types that subclass " +
-                    "androidx.work.ListenableWorker."
-            )
-        }
-    }
-
-    @Test
-    fun verifySingleAnnotatedConstructor() {
-        val myWorker = """
-        package androidx.hilt.work.test;
-
-        import android.content.Context;
-        import androidx.hilt.Assisted;
-        import androidx.hilt.work.WorkerInject;
-        import androidx.work.Worker;
-        import androidx.work.WorkerParameters;
-        import java.lang.String;
-
-        class MyWorker extends Worker {
-            @WorkerInject
-            MyWorker(@Assisted Context context, @Assisted WorkerParameters params) {
-                super(context, params);
-            }
-
-            @WorkerInject
-            MyWorker(Context context, WorkerParameters params, String s) {
-                super(context, params);
-            }
-        }
-        """.toJFO("androidx.hilt.work.test.MyWorker")
-
-        val compilation = compiler()
-            .compile(myWorker, Sources.LISTENABLE_WORKER, Sources.WORKER, Sources.WORKER_PARAMETERS)
-        assertThat(compilation).apply {
-            failed()
-            hadErrorCount(1)
-            hadErrorContainingMatch("Multiple @WorkerInject annotated constructors found.")
-        }
-    }
-
-    @Test
-    fun verifyNonPrivateConstructor() {
-        val myWorker = """
-        package androidx.hilt.work.test;
-
-        import android.content.Context;
-        import androidx.hilt.Assisted;
-        import androidx.hilt.work.WorkerInject;
-        import androidx.work.Worker;
-        import androidx.work.WorkerParameters;
-
-        class MyWorker extends Worker {
-            @WorkerInject
-            private MyWorker(@Assisted Context context, @Assisted WorkerParameters params) {
-                super(context, params);
-            }
-        }
-        """.toJFO("androidx.hilt.work.test.MyWorker")
-
-        val compilation = compiler()
-            .compile(myWorker, Sources.LISTENABLE_WORKER, Sources.WORKER, Sources.WORKER_PARAMETERS)
-        assertThat(compilation).apply {
-            failed()
-            hadErrorCount(1)
-            hadErrorContainingMatch(
-                "@WorkerInject annotated constructors must not be " +
-                    "private."
-            )
-        }
-    }
-
-    @Test
-    fun verifyInnerClassIsStatic() {
-        val myWorker = """
-        package androidx.hilt.work.test;
-
-        import android.content.Context;
-        import androidx.hilt.Assisted;
-        import androidx.hilt.work.WorkerInject;
-        import androidx.work.Worker;
-        import androidx.work.WorkerParameters;
-
-        class Outer {
-            class MyWorker extends Worker {
-                @WorkerInject
-                MyWorker(@Assisted Context context, @Assisted WorkerParameters params) {
-                    super(context, params);
-                }
-            }
-        }
-        """.toJFO("androidx.hilt.work.test.Outer")
-
-        val compilation = compiler()
-            .compile(myWorker, Sources.LISTENABLE_WORKER, Sources.WORKER, Sources.WORKER_PARAMETERS)
-        assertThat(compilation).apply {
-            failed()
-            hadErrorCount(1)
-            hadErrorContainingMatch(
-                "@WorkerInject may only be used on inner classes " +
-                    "if they are static."
-            )
-        }
-    }
-}
\ No newline at end of file
diff --git a/hilt/hilt-compiler/src/test/kotlin/androidx/hilt/work/WorkerStepTest.kt b/hilt/hilt-compiler/src/test/kotlin/androidx/hilt/work/WorkerStepTest.kt
new file mode 100644
index 0000000..9783c14
--- /dev/null
+++ b/hilt/hilt-compiler/src/test/kotlin/androidx/hilt/work/WorkerStepTest.kt
@@ -0,0 +1,241 @@
+/*
+ * Copyright 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.hilt.work
+
+import androidx.hilt.Sources
+import androidx.hilt.compiler
+import androidx.hilt.toJFO
+import com.google.testing.compile.CompilationSubject.assertThat
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+/**
+ * Tests that the Hilt compiler rejects invalid `@HiltWorker` declarations. Each test compiles a
+ * malformed worker and asserts that compilation fails with the expected error message.
+ */
+@RunWith(JUnit4::class)
+class WorkerStepTest {
+
+    /** A `@HiltWorker` class must subclass `androidx.work.ListenableWorker`. */
+    @Test
+    fun verifyEnclosingElementExtendsWorker() {
+        val myWorker = """
+        package androidx.hilt.work.test;
+
+        import android.content.Context;
+        import androidx.hilt.work.HiltWorker;
+        import androidx.work.WorkerParameters;
+        import dagger.assisted.Assisted;
+        import dagger.assisted.AssistedInject;
+
+        @HiltWorker
+        class MyWorker {
+            @AssistedInject
+            MyWorker(@Assisted Context context, @Assisted WorkerParameters params) { }
+        }
+        """.toJFO("androidx.hilt.work.test.MyWorker")
+
+        val compilation = compiler()
+            .compile(myWorker, Sources.LISTENABLE_WORKER, Sources.WORKER, Sources.WORKER_PARAMETERS)
+        assertThat(compilation).apply {
+            failed()
+            hadErrorCount(1)
+            hadErrorContainingMatch(
+                "@HiltWorker is only supported on types that subclass " +
+                    "androidx.work.ListenableWorker."
+            )
+        }
+    }
+
+    /** A `@HiltWorker` class must declare exactly one `@AssistedInject` constructor. */
+    @Test
+    fun verifySingleAnnotatedConstructor() {
+        val myWorker = """
+        package androidx.hilt.work.test;
+
+        import android.content.Context;
+        import androidx.hilt.work.HiltWorker;
+        import androidx.work.Worker;
+        import androidx.work.WorkerParameters;
+        import dagger.assisted.Assisted;
+        import dagger.assisted.AssistedInject;
+        import java.lang.String;
+
+        @HiltWorker
+        class MyWorker extends Worker {
+            @AssistedInject
+            MyWorker(@Assisted Context context, @Assisted WorkerParameters params) {
+                super(context, params);
+            }
+
+            @AssistedInject
+            MyWorker(Context context, WorkerParameters params, String s) {
+                super(context, params);
+            }
+        }
+        """.toJFO("androidx.hilt.work.test.MyWorker")
+
+        val compilation = compiler()
+            .compile(myWorker, Sources.LISTENABLE_WORKER, Sources.WORKER, Sources.WORKER_PARAMETERS)
+        assertThat(compilation).apply {
+            failed()
+            hadErrorCount(1)
+            hadErrorContainingMatch(
+                "@HiltWorker annotated class should contain exactly one @AssistedInject " +
+                    "annotated constructor."
+            )
+        }
+    }
+
+    /** The `@AssistedInject` constructor of a `@HiltWorker` class must not be private. */
+    @Test
+    fun verifyNonPrivateConstructor() {
+        val myWorker = """
+        package androidx.hilt.work.test;
+
+        import android.content.Context;
+        import androidx.hilt.work.HiltWorker;
+        import androidx.work.Worker;
+        import androidx.work.WorkerParameters;
+        import dagger.assisted.Assisted;
+        import dagger.assisted.AssistedInject;
+
+        @HiltWorker
+        class MyWorker extends Worker {
+            @AssistedInject
+            private MyWorker(@Assisted Context context, @Assisted WorkerParameters params) {
+                super(context, params);
+            }
+        }
+        """.toJFO("androidx.hilt.work.test.MyWorker")
+
+        val compilation = compiler()
+            .compile(myWorker, Sources.LISTENABLE_WORKER, Sources.WORKER, Sources.WORKER_PARAMETERS)
+        assertThat(compilation).apply {
+            failed()
+            hadErrorCount(1)
+            hadErrorContainingMatch(
+                "@AssistedInject annotated constructors must not be private."
+            )
+        }
+    }
+
+    /** A `@HiltWorker` class nested inside another class must be static. */
+    @Test
+    fun verifyInnerClassIsStatic() {
+        val myWorker = """
+        package androidx.hilt.work.test;
+
+        import android.content.Context;
+        import androidx.hilt.work.HiltWorker;
+        import androidx.work.Worker;
+        import androidx.work.WorkerParameters;
+        import dagger.assisted.Assisted;
+        import dagger.assisted.AssistedInject;
+
+        class Outer {
+            @HiltWorker
+            class MyWorker extends Worker {
+                @AssistedInject
+                MyWorker(@Assisted Context context, @Assisted WorkerParameters params) {
+                    super(context, params);
+                }
+            }
+        }
+        """.toJFO("androidx.hilt.work.test.Outer")
+
+        val compilation = compiler()
+            .compile(myWorker, Sources.LISTENABLE_WORKER, Sources.WORKER, Sources.WORKER_PARAMETERS)
+        assertThat(compilation).apply {
+            failed()
+            hadErrorCount(1)
+            hadErrorContainingMatch(
+                "@HiltWorker may only be used on inner classes " +
+                    "if they are static."
+            )
+        }
+    }
+
+    /** A `@HiltWorker` constructor must use `@AssistedInject` rather than `@Inject`. */
+    @Test
+    fun verifyConstructorAnnotation() {
+        val myWorker = """
+        package androidx.hilt.work.test;
+
+        import android.content.Context;
+        import androidx.hilt.work.HiltWorker;
+        import androidx.work.Worker;
+        import androidx.work.WorkerParameters;
+        import dagger.assisted.Assisted;
+        import dagger.assisted.AssistedInject;
+        import java.lang.String;
+        import javax.inject.Inject;
+
+        @HiltWorker
+        class MyWorker extends Worker {
+            @Inject
+            MyWorker(@Assisted Context context, @Assisted WorkerParameters params) {
+                super(context, params);
+            }
+        }
+        """.toJFO("androidx.hilt.work.test.MyWorker")
+
+        val compilation = compiler()
+            .compile(myWorker, Sources.LISTENABLE_WORKER, Sources.WORKER, Sources.WORKER_PARAMETERS)
+        assertThat(compilation).apply {
+            failed()
+            hadErrorContainingMatch(
+                "Worker constructor should be annotated with @AssistedInject instead of @Inject."
+            )
+        }
+    }
+
+    /** The assisted `Context` parameter must be declared before `WorkerParameters`. */
+    @Test
+    fun verifyAssistedParamOrder() {
+        val myWorker = """
+        package androidx.hilt.work.test;
+
+        import android.content.Context;
+        import androidx.hilt.work.HiltWorker;
+        import androidx.work.Worker;
+        import androidx.work.WorkerParameters;
+        import dagger.assisted.Assisted;
+        import dagger.assisted.AssistedInject;
+        import java.lang.String;
+
+        @HiltWorker
+        class MyWorker extends Worker {
+            @AssistedInject
+            MyWorker(@Assisted WorkerParameters params, @Assisted Context context) {
+                super(context, params);
+            }
+        }
+        """.toJFO("androidx.hilt.work.test.MyWorker")
+
+        val compilation = compiler()
+            .compile(myWorker, Sources.LISTENABLE_WORKER, Sources.WORKER, Sources.WORKER_PARAMETERS)
+        assertThat(compilation).apply {
+            failed()
+            hadErrorContainingMatch(
+                "The 'Context' parameter must be declared before the 'WorkerParameters' in the " +
+                    "@AssistedInject constructor of a @HiltWorker annotated class."
+            )
+        }
+    }
+}
diff --git a/hilt/hilt-work/proguard-rules.pro b/hilt/hilt-work/proguard-rules.pro
index 553d9ca..cd5c2e5 100644
--- a/hilt/hilt-work/proguard-rules.pro
+++ b/hilt/hilt-work/proguard-rules.pro
@@ -1,5 +1,2 @@
-# Keep class names of Hilt injected Workers since their name are used as a multibinding map key.
--keepclasseswithmembernames class * extends androidx.work.ListenableWorker {
-    @androidx.hilt.work.WorkerInject
-    <init>(...);
-}
\ No newline at end of file
+# Keep class names of @HiltWorker annotated Workers since their names are used as multibinding map keys.
+-keepnames @androidx.hilt.work.HiltWorker class * extends androidx.work.ListenableWorker
diff --git a/hilt/integration-tests/workerapp/src/main/java/androidx/hilt/integration/workerapp/SimpleWorker.kt b/hilt/integration-tests/workerapp/src/main/java/androidx/hilt/integration/workerapp/SimpleWorker.kt
index cc4c474..5fc9f80 100644
--- a/hilt/integration-tests/workerapp/src/main/java/androidx/hilt/integration/workerapp/SimpleWorker.kt
+++ b/hilt/integration-tests/workerapp/src/main/java/androidx/hilt/integration/workerapp/SimpleWorker.kt
@@ -18,14 +18,16 @@
 
 import android.content.Context
 import android.util.Log
-import androidx.hilt.Assisted
-import androidx.hilt.work.WorkerInject
+import androidx.hilt.work.HiltWorker
 import androidx.work.CoroutineWorker
 import androidx.work.Worker
 import androidx.work.WorkerParameters
+import dagger.assisted.Assisted
+import dagger.assisted.AssistedInject
 import javax.inject.Inject
 
-class SimpleWorker @WorkerInject constructor(
+@HiltWorker
+class SimpleWorker @AssistedInject constructor(
     @Assisted context: Context,
     @Assisted params: WorkerParameters,
     private val logger: MyLogger
@@ -36,7 +38,8 @@
     }
 }
 
-class SimpleCoroutineWorker @WorkerInject constructor(
+@HiltWorker
+class SimpleCoroutineWorker @AssistedInject constructor(
     @Assisted context: Context,
     @Assisted params: WorkerParameters,
     private val logger: MyLogger
@@ -48,7 +51,8 @@
 }
 
 object TopClass {
-    class NestedWorker @WorkerInject constructor(
+    @HiltWorker
+    class NestedWorker @AssistedInject constructor(
         @Assisted context: Context,
         @Assisted params: WorkerParameters,
         private val logger: MyLogger