[go: nahoru, domu]

Skip to content

Commit

Permalink
Export sentence classifier to CoreML
Browse files (browse the repository at this point in the history)
  • Loading branch information
Alon Palombo authored and Zach Nation committed Dec 25, 2017
1 parent 6402546 commit 40e81a3
Showing 1 changed file with 39 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@
from __future__ import division as _
from __future__ import absolute_import as _
import turicreate as _tc
import re as _re
import os as _os
from turicreate.toolkits._model import CustomModel as _CustomModel
from turicreate.toolkits._model import PythonProxy as _PythonProxy
from turicreate.toolkits._internal_utils import _toolkit_repr_print
import operator as _operator
import turicreate as _turicreate
import logging as _logging


def _BOW_FEATURE_EXTRACTOR(sf):
"""
Return an SFrame containing a bag of words representation of each column.
Expand Down Expand Up @@ -125,7 +122,6 @@ def __init__(self, state=None):

self.__proxy__ = _PythonProxy(state)


@classmethod
def _native_name(cls):
    """
    Return the native name string for this model class.

    Returns
    -------
    out : str
        Always the literal ``"sentence_classifier"``.
    """
    native_name = "sentence_classifier"
    return native_name
Expand All @@ -146,7 +142,7 @@ def _load_version(self, state, version):
state = _PythonProxy(state)
return SentenceClassifier(state)

def predict(self, dataset):
def predict(self, dataset, output_type='class'):
"""
Return predictions for ``dataset``, using the trained model.
Expand All @@ -157,6 +153,16 @@ def predict(self, dataset):
names as the features used for model training, but does not require
a target column. Additional columns are ignored.
output_type : {'class', 'probability_vector'}, optional
Form of the predictions which are one of:
- 'probability_vector': Prediction probability associated with each
class as a vector. The probability of the first class (sorted
alphanumerically by name of the class in the training set) is in
position 0 of the vector, the second in position 1 and so on.
- 'class': Class prediction. For multi-class classification, this
returns the class with maximum probability.
Returns
-------
out : SArray
Expand All @@ -177,7 +183,7 @@ def predict(self, dataset):
"""
m = self.__proxy__['classifier']
f = _BOW_FEATURE_EXTRACTOR
return m.predict(f(dataset))
return m.predict(f(dataset), output_type=output_type)

def classify(self, dataset):
"""
Expand Down Expand Up @@ -237,12 +243,11 @@ def _get_summary_struct(self):

def __repr__(self):
    """
    Return a printable summary of the model.

    The sections and their titles come from ``self._get_summary_struct()``
    and are rendered by ``_toolkit_repr_print`` with ``width=32``.

    Returns
    -------
    out : str
        Whatever ``_toolkit_repr_print`` produces for this model.
    """
    # The unused ``key_str`` format string from the original was dropped:
    # it was assigned but never read.
    width = 32
    (sections, section_titles) = self._get_summary_struct()
    return _toolkit_repr_print(self, sections, section_titles, width=width)

def evaluate(self, dataset, metric = 'auto', **kwargs):
def evaluate(self, dataset, metric='auto', **kwargs):
"""
Evaluate the model by making predictions of target values and comparing
these to actual values.
Expand Down Expand Up @@ -281,22 +286,44 @@ def evaluate(self, dataset, metric = 'auto', **kwargs):
create, predict, classify
"""
target = self.__proxy__['target']
m = self.__proxy__['classifier']
f = _BOW_FEATURE_EXTRACTOR
test = f(dataset)
return m.evaluate(test, **kwargs)
return m.evaluate(test, metric, **kwargs)

def summary(self):
    """
    Return the summary of the wrapped classifier.

    This forwards directly to ``summary()`` on the model stored under the
    ``'classifier'`` key of this object's proxy state, and returns whatever
    that call returns.
    """
    inner_model = self.__proxy__['classifier']
    return inner_model.summary()

def export_coreml(self, filename):
    """
    Export the model in Core ML format.

    Parameters
    ----------
    filename: str
        A valid filename where the model can be saved.

    Notes
    -----
    Only the wrapped classifier (``self.__proxy__['classifier']``) is handed
    to the exporter. The bag-of-words feature extraction that ``predict``
    applies before calling the classifier is not part of this call.
    NOTE(review): the exported model presumably expects those extracted
    features rather than raw text. Confirm this before documenting it for users.

    Examples
    --------
    >>> model.export_coreml("MyModel.mlmodel")
    """
    from turicreate.extensions import _logistic_classifier_export_as_model_asset
    from turicreate.toolkits import _coreml_utils

    # Metadata recorded alongside the exported model.
    description = _coreml_utils._mlmodel_short_description('sentence classifier')
    export_context = {
        'class': self.__class__.__name__,
        'version': _tc.__version__,
        'short_description': description,
    }
    inner_proxy = self.__proxy__['classifier'].__proxy__
    _logistic_classifier_export_as_model_asset(inner_proxy, filename, export_context)


def _get_str_columns(sf):
    """
    Return the names of the columns of ``sf`` whose dtype is ``str``.

    Names are returned as a list, in the order given by ``sf.column_names()``.
    """
    def _holds_strings(column_name):
        return sf[column_name].dtype == str

    return [column_name for column_name in sf.column_names()
            if _holds_strings(column_name)]

0 comments on commit 40e81a3

Please sign in to comment.