PredictionIO's DASE architecture brings the separation-of-concerns design principle to predictive engine development. DASE stands for the following components of an engine:
- Data - includes Data Source and Data Preparator
- Algorithm(s)
- Serving
- Evaluator
Let's look at the code and see how you can customize the engine you built from the E-Commerce Recommendation Engine Template.
The Engine Design
As you can see from the Quick Start, MyECommerceRecommendation takes a JSON prediction query, e.g. `{ "userEntityId": "u1", "number": 4 }`, and returns a JSON predicted result. The `Query` class defines the format of such a query:

```java
public class Query implements Serializable {
  private final String userEntityId;
  private final int number;
  private final Set<String> categories;
  private final Set<String> whitelist;
  private final Set<String> blacklist;
  ...
}
```
The `PredictedResult` and `ItemScore` classes define the format of the predicted result, such as:

```json
{"itemScores":[
  {"itemEntityId":"22","score":4.07},
  {"itemEntityId":"62","score":4.05},
  {"itemEntityId":"75","score":4.04},
  {"itemEntityId":"68","score":3.81}
]}
```

with:

```java
public class PredictedResult implements Serializable {
  private final List<ItemScore> itemScores;
  ...
}

public class ItemScore implements Serializable, Comparable<ItemScore> {
  private final String itemEntityId;
  private final double score;
  ...
}
```
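Since `ItemScore` implements `Comparable<ItemScore>`, predicted results can be ranked by score. The template elides its actual comparison logic, so here is a minimal illustrative sketch (class and method names are ours, not the template's):

```java
import java.util.Collections;
import java.util.List;

public class ItemScoreDemo {
  // A simplified ItemScore: an item entity ID plus a score, ordered by score.
  static class ItemScore implements Comparable<ItemScore> {
    final String itemEntityId;
    final double score;

    ItemScore(String itemEntityId, double score) {
      this.itemEntityId = itemEntityId;
      this.score = score;
    }

    @Override
    public int compareTo(ItemScore other) {
      // ascending by score, so Collections.max picks the highest-scoring item
      return Double.compare(this.score, other.score);
    }
  }

  // Return the ID of the highest-scoring item.
  public static String best(List<ItemScore> scores) {
    return Collections.max(scores).itemEntityId;
  }
}
```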
Finally, `RecommendationEngine` is the Engine Factory class that defines the components this engine will use: Data Source, Data Preparator, Algorithm(s) and Serving components.

```java
public class RecommendationEngine extends EngineFactory {
  @Override
  public BaseEngine<EmptyParams, Query, PredictedResult, Object> apply() {
    return new Engine<>(
        DataSource.class,
        Preparator.class,
        Collections.<String, Class<? extends BaseAlgorithm<PreparedData, ?, Query, PredictedResult>>>singletonMap("algo", Algorithm.class),
        Serving.class
    );
  }
}
```
Spark MLlib
The PredictionIO E-Commerce Recommendation Engine Template integrates Spark's MLlib ALS algorithm under the DASE architecture. We will take a closer look at the DASE code below.
The MLlib ALS algorithm takes training data of RDD type, i.e. `RDD[Rating]`, and trains a model, which is a `MatrixFactorizationModel` object.
You can visit here to learn more about MLlib's ALS collaborative filtering algorithm.
Data
In the DASE architecture, data is prepared by two components sequentially: DataSource and DataPreparator. They take data from the data store and prepare it for the Algorithm.
Data Source
In the DataSource class, the `readTraining` method reads and selects data from the Event Store. It returns `TrainingData`.

```java
public TrainingData readTraining(SparkContext sc) {
  // create a JavaPairRDD of (entityID, User)
  JavaPairRDD<String, User> usersRDD = PJavaEventStore.aggregateProperties(...)

  // create a JavaPairRDD of (entityID, Item)
  JavaPairRDD<String, Item> itemsRDD = PJavaEventStore.aggregateProperties(...)

  // find all view events
  JavaRDD<UserItemEvent> viewEventsRDD = PJavaEventStore.find(...)

  // find all buy events
  JavaRDD<UserItemEvent> buyEventsRDD = PJavaEventStore.find(...)

  return new TrainingData(usersRDD, itemsRDD, viewEventsRDD, buyEventsRDD);
}
```
PredictionIO automatically loads the parameters of the Data Source specified in MyECommerceRecommendation/engine.json, including appName.

In engine.json:

```json
{
  ...
  "datasource": {
    "params" : {
      "appName": "MyApp1"
    }
  },
  ...
}
```
In `readTraining()`, `PJavaEventStore` is an object that provides functions to access the data collected by the PredictionIO Event Server.
This E-Commerce Recommendation Engine Template requires "user" and "item" entities that are set by events.
`PJavaEventStore.aggregateProperties(...)` aggregates the properties of the `user` and `item` entities that are set, unset, or deleted by the special events $set, $unset and $delete. Please refer to the Event API for more details on using these events.
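For example, a `$set` event that creates an `item` entity with a `categories` property might look like this when posted to the Event Server (the entity ID and category values here are illustrative):

```json
{
  "event": "$set",
  "entityType": "item",
  "entityId": "i1",
  "properties": {
    "categories": ["category1", "category2"]
  }
}
```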
The following code aggregates the properties of `user` entities and then maps each result to a `User` object.

```java
JavaPairRDD<String, User> usersRDD = PJavaEventStore.aggregateProperties(
    dsp.getAppName(),
    "user",
    OptionHelper.<String>none(),
    OptionHelper.<DateTime>none(),
    OptionHelper.<DateTime>none(),
    OptionHelper.<List<String>>none(),
    sc)
    .mapToPair(new PairFunction<Tuple2<String, PropertyMap>, String, User>() {
      @Override
      public Tuple2<String, User> call(Tuple2<String, PropertyMap> entityIdProperty) throws Exception {
        Set<String> keys = JavaConversions$.MODULE$.setAsJavaSet(entityIdProperty._2().keySet());
        Map<String, String> properties = new HashMap<>();

        for (String key : keys) {
          properties.put(key, entityIdProperty._2().get(key, String.class));
        }

        User user = new User(entityIdProperty._1(), ImmutableMap.copyOf(properties));

        return new Tuple2<>(user.getEntityId(), user);
      }
    });
```
In the template, the `User` object is a placeholder for you to customize and expand.
Similarly, the following code aggregates `item` properties and then maps each result to an `Item` object. By default, this template assumes each item has an optional property `categories`, which is a list of String.

```java
JavaPairRDD<String, Item> itemsRDD = PJavaEventStore.aggregateProperties(
    dsp.getAppName(),
    "item",
    OptionHelper.<String>none(),
    OptionHelper.<DateTime>none(),
    OptionHelper.<DateTime>none(),
    OptionHelper.<List<String>>none(),
    sc)
    .mapToPair(new PairFunction<Tuple2<String, PropertyMap>, String, Item>() {
      @Override
      public Tuple2<String, Item> call(Tuple2<String, PropertyMap> entityIdProperty) throws Exception {
        List<String> categories = entityIdProperty._2().getStringList("categories");
        Item item = new Item(entityIdProperty._1(), ImmutableSet.copyOf(categories));

        return new Tuple2<>(item.getEntityId(), item);
      }
    });
```
`PJavaEventStore.find(...)` specifies the events that you want to read. In this case, "user view item" and "user buy item" events are read:

```java
JavaRDD<UserItemEvent> viewEventsRDD = PJavaEventStore.find(
    dsp.getAppName(),
    OptionHelper.<String>none(),
    OptionHelper.<DateTime>none(),
    OptionHelper.<DateTime>none(),
    OptionHelper.some("user"),
    OptionHelper.<String>none(),
    OptionHelper.some(Collections.singletonList("view")),
    OptionHelper.<Option<String>>none(),
    OptionHelper.<Option<String>>none(),
    sc)
    .map(new Function<Event, UserItemEvent>() {
      @Override
      public UserItemEvent call(Event event) throws Exception {
        return new UserItemEvent(event.entityId(), event.targetEntityId().get(), event.eventTime().getMillis(), UserItemEventType.VIEW);
      }
    });
```
Similarly, we read buy events from the Event Server:

```java
JavaRDD<UserItemEvent> buyEventsRDD = PJavaEventStore.find(
    dsp.getAppName(),
    OptionHelper.<String>none(),
    OptionHelper.<DateTime>none(),
    OptionHelper.<DateTime>none(),
    OptionHelper.some("user"),
    OptionHelper.<String>none(),
    OptionHelper.some(Collections.singletonList("buy")),
    OptionHelper.<Option<String>>none(),
    OptionHelper.<Option<String>>none(),
    sc)
    .map(new Function<Event, UserItemEvent>() {
      @Override
      public UserItemEvent call(Event event) throws Exception {
        return new UserItemEvent(event.entityId(), event.targetEntityId().get(), event.eventTime().getMillis(), UserItemEventType.BUY);
      }
    });
```
`TrainingData` contains the Java RDDs of `User`, `Item`, view events, and buy events. The class definition of `TrainingData` is:

```java
public class TrainingData implements Serializable, SanityCheck {
  private final JavaPairRDD<String, User> users;
  private final JavaPairRDD<String, Item> items;
  private final JavaRDD<UserItemEvent> viewEvents;
  private final JavaRDD<UserItemEvent> buyEvents;
  ...
}
```
PredictionIO then passes the returned `TrainingData` object to the Data Preparator.
Data Preparator
In the Preparator, the `prepare` method takes `TrainingData` as its input and performs any necessary feature selection and data processing tasks. At the end, it returns `PreparedData`, which should contain the data the Algorithm needs.

By default, `prepare` simply includes the unprocessed `TrainingData` in `PreparedData`:

```java
public class Preparator extends PJavaPreparator<TrainingData, PreparedData> {
  @Override
  public PreparedData prepare(SparkContext sc, TrainingData trainingData) {
    return new PreparedData(trainingData);
  }
}
```
PredictionIO passes the returned `PreparedData` object to the Algorithm's `train` function.
Algorithm
In the Algorithm class, the two methods of interest are `train` and `predict`. `train` is responsible for training the predictive model; `predict` is responsible for using this model to make a prediction.
Algorithm parameters
The algorithm takes the following parameters, as defined by the `AlgorithmParams` class:

```java
public class AlgorithmParams implements Params {
  private final long seed;
  private final int rank;
  private final int iteration;
  private final double lambda;
  private final String appName;
  private final List<String> similarItemEvents;
  private final boolean unseenOnly;
  private final List<String> seenItemEvents;
  ...
}
```
Parameter description:

- appName: Your App name. Events defined by "seenItemEvents" and "similarItemEvents" will be read from this app during `predict`.
- unseenOnly: true or false. Set to true if you want to recommend unseen items only. Seen items are defined by seenItemEvents: if the user has performed any of these events on an item, the item is treated as seen.
- seenItemEvents: A list of user-to-item events which will be treated as seen events. Used when unseenOnly is set to true.
- similarItemEvents: A list of user-to-item events which will be used to find items similar to the items on which the user has performed these events.
- rank: Parameter of the MLlib ALS algorithm. Number of latent features.
- iteration: Parameter of the MLlib ALS algorithm. Number of iterations.
- lambda: Regularization parameter of the MLlib ALS algorithm.
- seed: A random seed of the MLlib ALS algorithm.
train(...)
`train` is called when you run pio train. This is where the MLlib ALS algorithm, i.e. `ALS.trainImplicit()`, is used to train a predictive model. In addition, we also count the number of times each item has been bought; this serves as a default model, used during `predict` when no ALS model or other useful information about the user is available.

```java
public Model train(SparkContext sc, PreparedData preparedData) {
  ...
  MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(
      JavaRDD.toRDD(ratings),
      ap.getRank(),
      ap.getIteration(),
      ap.getLambda(),
      -1,
      1.0,
      ap.getSeed());
  ...
}
```
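The buy-count default model amounts to tallying buy events per item. A minimal plain-Java sketch of that tally, without Spark (the class and method names are ours, not the template's):

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class BuyCount {
  // Count how many times each item ID appears in a list of bought item IDs.
  public static Map<String, Integer> countBuys(List<String> boughtItemIds) {
    Map<String, Integer> counts = new HashMap<>();
    for (String itemId : boughtItemIds) {
      counts.merge(itemId, 1, Integer::sum);  // add 1 to any existing count
    }
    return counts;
  }
}
```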
Working with Spark MLlib's ALS.trainImplicit(....)
The MLlib ALS algorithm does not support String user IDs and item IDs; `ALS.trainImplicit` accordingly assumes an int-only `Rating` object. A view event is an implicit event that does not carry an explicit rating value, and `ALS.trainImplicit()` supports such implicit preferences: a higher rating value means higher confidence that the user prefers the item. Hence we can aggregate how many times a user has viewed an item to indicate the confidence level that the user may prefer the item.
Here are the steps to use the MLlib ALS algorithm:

- Map the user and item string IDs of each view event into integer IDs, as required by `Rating`.
- Filter out the events with invalid user or item IDs.
- Use `reduceByKey()` to add up all values for events with the same user-item combination.
- Create a `Rating` object using the user index, item index, and summed-up score.

```java
JavaRDD<Rating> ratings = data.getViewEvents().mapToPair(
    new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
      @Override
      public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent viewEvent) throws Exception {
        Integer userIndex = userIndexMap.get(viewEvent.getUser());
        Integer itemIndex = itemIndexMap.get(viewEvent.getItem());

        return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
      }
    }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
      @Override
      public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
        return (element != null);
      }
    }).reduceByKey(new Function2<Integer, Integer, Integer>() {
      @Override
      public Integer call(Integer integer, Integer integer2) throws Exception {
        return integer + integer2;
      }
    }).map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Rating>() {
      @Override
      public Rating call(Tuple2<Tuple2<Integer, Integer>, Integer> userItemCount) throws Exception {
        return new Rating(userItemCount._1()._1(), userItemCount._1()._2(), userItemCount._2().doubleValue());
      }
    });
```
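The `userIndexMap` and `itemIndexMap` used above simply assign a distinct integer index to each string ID. A minimal plain-Java sketch of that assignment (the class and method names are illustrative, not the template's actual code):

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class IndexMap {
  // Assign a distinct integer index (0, 1, 2, ...) to each string ID,
  // keeping the first index assigned if an ID appears more than once.
  public static Map<String, Integer> buildIndexMap(List<String> ids) {
    Map<String, Integer> indexMap = new HashMap<>();
    for (String id : ids) {
      // size() is evaluated before the put, so it yields the next free index
      indexMap.putIfAbsent(id, indexMap.size());
    }
    return indexMap;
  }
}
```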
In addition to `RDD[Rating]`, `ALS.trainImplicit` takes the following parameters: rank, iterations, lambda and seed.

The values of these parameters are specified in the algorithms section of engine.json:

```json
{
  ...
  "algorithms": [
    {
      "name": "als",
      "params": {
        "appName": "MyApp1",
        "unseenOnly": true,
        "seenItemEvents": ["buy", "view"],
        "similarItemEvents": ["view"],
        "rank": 10,
        "iteration": 20,
        "lambda": 0.01,
        "seed": 3
      }
    }
  ]
  ...
}
```
The parameters `appName`, `unseenOnly`, `seenItemEvents` and `similarItemEvents` are used during `predict()`, which will be explained later.

PredictionIO automatically loads these values into the `AlgorithmParams` field of the `Algorithm`.
The `seed` parameter is used internally by the MLlib ALS algorithm to generate random values. Specify a fixed value for the `seed` if you want deterministic results (for example, when you are testing).
`ALS.trainImplicit()` returns a `MatrixFactorizationModel`, which contains two RDDs: userFeatures and productFeatures. They correspond to the user-by-latent-features matrix and the item-by-latent-features matrix, respectively.
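A predicted score for a (user, item) pair is the dot product of the corresponding rows of those two matrices. A minimal plain-Java sketch of that computation (MLlib computes this internally; the class and method names are ours):

```java
public class DotProduct {
  // Score a (user, item) pair as the dot product of the user's and the
  // item's latent feature vectors, as matrix factorization models do.
  public static double score(double[] userFeatures, double[] itemFeatures) {
    double sum = 0.0;
    for (int i = 0; i < userFeatures.length; i++) {
      sum += userFeatures[i] * itemFeatures[i];
    }
    return sum;
  }
}
```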
In addition to the latent feature vectors, the item properties (e.g. categories) and the popular count are also used during `predict()`. Hence, we save these data along with the feature vectors.
PredictionIO will store the returned model after training.
predict(...)
`predict` is called when you send a JSON query to http://localhost:8000/queries.json. PredictionIO converts the query, such as `{ "userEntityId": "u1", "number": 4 }`, to the `Query` class we defined previously.
We can use the user features and item features stored in the model to calculate the scores of items for the user.
This template also supports additional business logic, such as filtering items by categories, recommending only whitelisted items, excluding blacklisted items, recommending unseen items only, and excluding unavailable items defined in a constraint event.
The `predict()` function does the following:

- Get the user feature vector from the model.
- If there is a feature vector for the user, recommend the top N items based on the user feature and item features.
- If there is no feature vector for the user, use the recent items acted on by the user (defined by the `similarItemEvents` parameter) to recommend similar items.
- If there are no recent `similarItemEvents` available for the user, recommend popular items.
Only items that satisfy all of the following conditions will be recommended. By default, an item will be recommended if:
- it belongs to one of the categories defined in query.
- it is one of the whitelisted items if a whitelist is defined.
- it is not on the blacklist.
- it is available.
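These conditions can be pictured as a simple predicate over a candidate item. A minimal plain-Java sketch, assuming hypothetical `queryCategories`, `whitelist`, and `blacklist` sets where an empty set means "no constraint" (the class, method, and parameter names are illustrative, not the template's):

```java
import java.util.Set;

public class CandidateFilter {
  // Decide whether a candidate item may be recommended, mirroring the
  // conditions above: availability, blacklist, whitelist, category match.
  public static boolean isCandidate(
      Set<String> itemCategories, boolean available, String itemId,
      Set<String> queryCategories, Set<String> whitelist, Set<String> blacklist) {
    if (!available) return false;                  // item must be available
    if (blacklist.contains(itemId)) return false;  // must not be blacklisted
    if (!whitelist.isEmpty() && !whitelist.contains(itemId)) return false;  // whitelist, if defined
    if (!queryCategories.isEmpty()) {              // category filter, if defined
      boolean matches = false;
      for (String c : itemCategories) {
        if (queryCategories.contains(c)) { matches = true; break; }
      }
      if (!matches) return false;
    }
    return true;
  }
}
```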
PredictionIO passes the returned `PredictedResult` object to Serving.
Serving
The `serve` method of the `Serving` class processes the predicted result. It is also responsible for combining multiple predicted results into one if you have more than one predictive model. Serving then returns the final predicted result; PredictionIO converts it to a JSON response automatically.

```java
public class Serving extends LJavaServing<Query, PredictedResult> {
  @Override
  public PredictedResult serve(Query query, Seq<PredictedResult> predictions) {
    return predictions.head();
  }
}
```
When you send a JSON query to http://localhost:8000/queries.json, the `PredictedResult` from each model is passed to `serve` as a sequence, i.e. `Seq<PredictedResult>`.

An engine can train multiple models if you specify more than one Algorithm component in `RecommendationEngine`. Since we only have one algorithm, this `Seq` contains one element.
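If you did run several algorithms, `serve` would need to combine their results instead of just taking `predictions.head()`. One common approach is to average the scores per item; a minimal plain-Java sketch of that merge policy (the class, method, and data shapes are illustrative, not part of the template):

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class MergeScores {
  // Merge several per-model score maps (itemId -> score) into one map
  // of itemId -> average score across the models that scored the item.
  public static Map<String, Double> averageScores(List<Map<String, Double>> perModelScores) {
    Map<String, Double> sums = new LinkedHashMap<>();
    Map<String, Integer> counts = new LinkedHashMap<>();
    for (Map<String, Double> scores : perModelScores) {
      for (Map.Entry<String, Double> e : scores.entrySet()) {
        sums.merge(e.getKey(), e.getValue(), Double::sum);
        counts.merge(e.getKey(), 1, Integer::sum);
      }
    }
    Map<String, Double> averaged = new LinkedHashMap<>();
    for (Map.Entry<String, Double> e : sums.entrySet()) {
      averaged.put(e.getKey(), e.getValue() / counts.get(e.getKey()));
    }
    return averaged;
  }
}
```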