-
Notifications
You must be signed in to change notification settings - Fork 32
Open
Labels
feature — a feature request or enhancement
Description
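I chose the `mtcars` dataset to make the reproducible example below.
For an xgboost model trained with `base_score` set to the label mean, the predictions from the SQL generated by `tidypredict_sql()` do not match those from `predict()`: the mean absolute difference printed at the end should be close to 0 but is about 0.047.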
library(xgboost)
#> Warning: package 'xgboost' was built under R version 3.6.1
library(tidyverse)
#> Registered S3 methods overwritten by 'ggplot2':
#> method from
#> [.quosures rlang
#> c.quosures rlang
#> print.quosures rlang
#> Warning: package 'dplyr' was built under R version 3.6.1
# Training data: mtcars with the binary transmission column `am` renamed
# to `y`, which is used as the label below. Column order is unchanged.
train_data <- mtcars
names(train_data)[names(train_data) == "am"] <- "y"
# Feature matrix = every column except the label `y`; label = `y`.
dtrain <- xgb.DMatrix(
  data = as.matrix(select(train_data, -y)),
  label = train_data$y
)
# Fit a 10-round boosted model of depth-1 trees (stumps) with a logistic
# objective. `base_score` (the model's starting prediction) is set to the
# observed positive rate instead of xgboost's default (0.5 in the versions
# current at the time), to fix the uncalibrated starting point.
# NOTE(review): the xgboost R package draws random numbers from R's RNG, so
# set.seed() is what makes runs reproducible; `seed = 1` is kept as in the
# original report. With max_depth = 1 and no subsampling there is no
# randomness in this fit anyway.
set.seed(1)
xgb_model <- xgb.train(
  # Booster parameters go in `params`, the documented interface, rather
  # than through `...`.
  params = list(
    max_depth = 1,
    objective = "binary:logistic",
    base_score = mean(train_data$y),
    seed = 1
  ),
  data = dtrain,
  # Spelled out: `nround` only worked through partial argument matching.
  nrounds = 10
)

# Reference predictions straight from xgboost (probabilities for the
# binary:logistic objective), to compare against the SQL translation.
pred_from_model <- predict(xgb_model, newdata = dtrain)
library(sqldf)
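#> Warning: package 'sqldf' was built under R version 3.6.1
#> Loading required package: gsubfn
#> Warning: package 'gsubfn' was built under R version 3.6.1
#> Loading required package: proto
#> Warning: package 'proto' was built under R version 3.6.1
#> Loading required package: RSQLite
#> Warning: package 'RSQLite' was built under R version 3.6.1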
library(tidypredict)
#> Warning: package 'tidypredict' was built under R version 3.6.1
# Score the same model in SQL:
# - tidypredict_sql() turns the model into a SQL expression.
# - The expression is wrapped in a SELECT over the mtcars table, which has
#   the same feature columns as train_data.
# - sqldf runs the query, and pull() extracts the (single) result column.
# To inspect the generated SQL, print the paste(...) result before passing
# it to sqldf().
# NOTE(review): the query has no ORDER BY, so the comparison below assumes
# SQLite returns rows in mtcars' original order.
pred_from_tidypredict <- pull(sqldf(paste(
  "select ",
  tidypredict_sql(xgb_model, dbplyr::simulate_dbi()),
  " from mtcars"
)))

# Mean absolute gap between the two prediction vectors. 0 would mean the
# SQL translation reproduces predict() exactly.
mean(abs(pred_from_model - pred_from_tidypredict))
#> [1] 0.04692561
Created on 2019-10-20 by the reprex package (v0.3.0)
Metadata
Metadata
Assignees
Labels
feature — a feature request or enhancement