Description
The NaiveBayesTextClassifierPredict function uses the model table generated by the NaiveBayesTextClassifierTrainer function to predict outcomes for test data.
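As a rough illustration of what scoring test tokens against a model table can look like, the base R sketch below ranks categories for a single document by summing log token probabilities, in the multinomial style. It is not the Teradata implementation: the toy model table, its column names (token, category, prob), the uniform class priors, and the helper score_document are all assumptions made for this example.

```r
# Toy model table: one row per (token, category) holding P(token | category).
# All values are made up for illustration.
model <- data.frame(
  token    = c("refund", "refund", "crash", "crash"),
  category = c("billing", "technical", "billing", "technical"),
  prob     = c(0.8, 0.2, 0.1, 0.9),
  stringsAsFactors = FALSE
)

# Hypothetical helper: score one document's tokens against the model table
# and return the top.k categories by summed log probability.
score_document <- function(tokens, model, top.k = 1) {
  # Join every token occurrence to its per-category probabilities.
  # Tokens that are absent from the model are dropped by the inner join.
  doc <- data.frame(token = tokens, stringsAsFactors = FALSE)
  joined <- merge(doc, model, by = "token")

  # Sum log probabilities per category (uniform priors assumed).
  scores <- vapply(split(log(joined$prob), joined$category), sum, numeric(1))

  # Keep the top.k highest-scoring categories.
  head(sort(scores, decreasing = TRUE), top.k)
}

top <- score_document(c("refund", "refund", "crash"), model, top.k = 1)
print(names(top))
```

With these toy numbers the "billing" category scores highest, because the repeated "refund" token outweighs the single "crash" token.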
Usage
td_naivebayes_textclassifier_predict_sqle (
    object = NULL,
    newdata = NULL,
    input.token.column = NULL,
    doc.id.columns = NULL,
    model.type = "MULTINOMIAL",
    top.k = NULL,
    model.token.column = NULL,
    model.category.column = NULL,
    model.prob.column = NULL,
    newdata.partition.column = NULL)

## S3 method for class 'td_naivebayes_textclassifier_mle'
predict(
    object = NULL,
    newdata = NULL,
    input.token.column = NULL,
    doc.id.columns = NULL,
    model.type = "MULTINOMIAL",
    top.k = NULL,
    model.token.column = NULL,
    model.category.column = NULL,
    model.prob.column = NULL,
    newdata.partition.column = NULL)
Arguments
object
    Required Argument.
newdata
    Required Argument.
newdata.partition.column
    Partition By columns for newdata.
input.token.column
    Required Argument.
doc.id.columns
    Required Argument.
model.type
    Optional Argument.
top.k
    Optional Argument.
model.token.column
    Optional Argument.
model.category.column
    Optional Argument.
model.prob.column
    Optional Argument.
Value
The function returns an object of class "td_naivebayes_textclassifier_predict_sqle", which is a named list containing a Teradata tbl object. The list member can be referenced directly with the "$" operator using the name: result
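The "$" access described above can be tried without a database connection using a stand-in object. In the sketch below, the mock list, its class attribute, and the columns of the toy data frame (doc_id, prediction, loglik) are assumptions for illustration only; a real call returns a Teradata tbl in the result member, not a local data frame.

```r
# Stand-in for the object returned by the predict function.
# The data frame and its columns are hypothetical placeholders.
mock_out <- structure(
  list(
    result = data.frame(
      doc_id     = c(1, 2),
      prediction = c("billing", "technical"),
      loglik     = c(-2.75, -1.90),
      stringsAsFactors = FALSE
    )
  ),
  class = "td_naivebayes_textclassifier_predict_sqle"
)

# Reference the named list member directly with "$".
predictions <- mock_out$result
print(predictions)
```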
Examples
# Get the current context/connection
con <- td_get_context()$connection

# Load example data.
loadExampleData("naivebayes_textclassifier_predict_example", "token_table",
                "complaints_tokens_test")

# Create remote tibble objects.
token_table <- tbl(con, "token_table")
complaints_tokens_test <- tbl(con, "complaints_tokens_test")

# Example -
# Create the model
textclassifier_out <- td_naivebayes_textclassifier_mle(data = token_table,
                                                       data.partition.column = c("category"),
                                                       token.column = "token",
                                                       doc.id.columns = c("doc_id"),
                                                       doc.category.column = "category",
                                                       model.type = "Bernoulli"
                                                       )

# Predict the output
predict_out <- td_naivebayes_textclassifier_predict_sqle(newdata = complaints_tokens_test,
                                                         object = textclassifier_out,
                                                         newdata.partition.column = "doc_id",
                                                         input.token.column = "token",
                                                         doc.id.columns = c("doc_id"),
                                                         model.type = "Bernoulli",
                                                         top.k = 1
                                                         )

# Alternatively use S3 predict method to find the predictions.
predict_result <- predict(textclassifier_out,
                          newdata = complaints_tokens_test,
                          newdata.partition.column = "doc_id",
                          input.token.column = "token",
                          doc.id.columns = c("doc_id"),
                          model.type = "Bernoulli",
                          top.k = 1)