thanks! @ablaom for the quick response.
My problem is specifically with evaluate!
-ing the composed model (for CV).
here is a simplified code example
module Example
using MLJBase
using RDatasets: dataset
using MLJXGBoostInterface
# ---
mutable struct ExampleTransformer <: Static
end
function MLJBase.transform(self::ExampleTransformer, _, X)
selectrows(X, 1:size(X)[1] - 10)
end
# ---
mutable struct ExampleComposed <: ProbabilisticNetworkComposite
transformer :: ExampleTransformer
classifier :: XGBoostClassifier
end
function MLJBase.prefit(self::ExampleComposed, verbosity, X, y)
Xs = source(X)
ys = source(y)
mach1 = machine(:transformer)
X1 = transform(mach1, Xs)
# --------
# here I change the size of the `target`
y1 = node((y -> y[1:end - 10]), ys) # <---
# --------
mach2 = machine(:classifier, X1, y1)
yhat = predict(mach2, X1)
return (;
predict = yhat,
)
end
function test()
iris = dataset("datasets", "iris")
y, X = unpack(iris, ==(:Species))
y = categorical(map((x -> x == "setosa"), y))
m = machine(
ExampleComposed(
ExampleTransformer(),
XGBoostClassifier()
),
X, y)
m |> fit!
# ==================================================
# The error I get here -
# ERROR: DimensionMismatch: Encountered two objects with sizes (15,) and (25,) which needed to match but don't.
# ==================================================
evaluate!(m, measure=auc)
end
end
the error that I get here is -
ERROR: DimensionMismatch: Encountered two objects with sizes (15,) and (25,) which needed to match but don't.
Stacktrace:
[1] check_dimensions
@ ~/.julia/packages/MLJBase/ByFwA/src/utilities.jl:145 [inlined]
[2] _check(measure::AreaUnderCurve, yhat::UnivariateFiniteVector{…}, y::CategoricalArrays.CategoricalVector{…})
@ MLJBase ~/.julia/packages/MLJBase/ByFwA/src/measures/measures.jl:60
[3] Measure
@ ~/.julia/packages/MLJBase/ByFwA/src/measures/measures.jl:132 [inlined]
[4] value
@ ~/.julia/packages/MLJBase/ByFwA/src/measures/measures.jl:202 [inlined]
[5] value
@ ~/.julia/packages/MLJBase/ByFwA/src/measures/measures.jl:196 [inlined]
[6] (::MLJBase.var"#326#332"{…})(m::AreaUnderCurve, op::Function)
@ MLJBase ~/.julia/packages/MLJBase/ByFwA/src/resampling.jl:1237
[7] #4
@ ./generator.jl:36 [inlined]
[8] iterate
@ ./generator.jl:47 [inlined]
[9] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{…}}, Base.var"#4#5"{MLJBase.var"#326#332"{…}}})
@ Base ./array.jl:834
[10] map(::Function, ::Vector{AreaUnderCurve}, ::Vector{typeof(predict)})
@ Base ./abstractarray.jl:3409
[11] fit_and_extract_on_fold
@ ~/.julia/packages/MLJBase/ByFwA/src/resampling.jl:1230 [inlined]
[12] (::MLJBase.var"#307#308"{MLJBase.var"#fit_and_extract_on_fold#330"{…}, Machine{…}, Int64})(k::Int64)
@ MLJBase ~/.julia/packages/MLJBase/ByFwA/src/resampling.jl:1056
[13] _mapreduce(f::MLJBase.var"#307#308"{…}, op::typeof(vcat), ::IndexLinear, A::UnitRange{…})
@ Base ./reduce.jl:440
[14] _mapreduce_dim
@ ./reducedim.jl:365 [inlined]
[15] mapreduce
@ ./reducedim.jl:357 [inlined]
[16] _evaluate!(func::MLJBase.var"#fit_and_extract_on_fold#330"{…}, mach::Machine{…}, ::CPU1{…}, nfolds::Int64, verbosity::Int64)
@ MLJBase ~/.julia/packages/MLJBase/ByFwA/src/resampling.jl:1055
[17] evaluate!(mach::Machine{…}, resampling::Vector{…}, weights::Nothing, class_weights::Nothing, rows::Nothing, verbosity::Int64, repeats::Int64, measures::Vector{…}, operations::Vector{…}, acceleration::CPU1{…}, force::Bool, logger::Nothing, user_resampling::CV)
@ MLJBase ~/.julia/packages/MLJBase/ByFwA/src/resampling.jl:1259
[18] evaluate!(::Machine{…}, ::CV, ::Nothing, ::Nothing, ::Nothing, ::Int64, ::Int64, ::Vector{…}, ::Vector{…}, ::CPU1{…}, ::Bool, ::Nothing, ::CV)
@ MLJBase ~/.julia/packages/MLJBase/ByFwA/src/resampling.jl:1335
[19] evaluate!(mach::Machine{…}; resampling::CV, measures::Nothing, measure::AreaUnderCurve, weights::Nothing, class_weights::Nothing, operations::Nothing, operation::Nothing, acceleration::CPU1{…}, rows::Nothing, repeats::Int64, force::Bool, check_measure::Bool, verbosity::Int64, logger::Nothing)
@ MLJBase ~/.julia/packages/MLJBase/ByFwA/src/resampling.jl:1015
[20] test()
@ Main.Example /workspaces/blab/src/Example.jl:59
[21] top-level scope
@ REPL[71]:1