# Tests for the PipeOp base class.
# `context()` is deprecated in testthat's 3rd edition (required by expect_snapshot()
# in this file); test reports are labelled by the file name instead.

# PipeOpDebugBasic is defined in helper_pipeops.R
test_that("PipeOp - General functions", {
  # Check the standard fields of a freshly constructed PipeOp, then train it once
  # and check the resulting state.
  po_1 = PipeOpDebugBasic$new()
  expect_class(po_1, "PipeOpDebugBasic")
  # expect_equal() reports the actual id on failure, unlike expect_true(x == y)
  expect_equal(po_1$id, "debug.basic")
  expect_false(po_1$is_trained)
  expect_class(po_1$param_set, "ParamSet")
  expect_list(po_1$param_set$values, names = "unique")
  expect_output(print(po_1), "PipeOp")
  expect_equal(po_1$packages, "mlr3pipelines")
  expect_null(po_1$state)
  # checkmate's expect_subset() is recorded as a testthat expectation;
  # assert_subset() would only throw and never count as one
  expect_subset(po_1$tags, mlr_reflections$pipeops$valid_tags)
  expect_error(po_1$predict(list(tsk("iris"))), "has not been trained yet")

  # Training prints a message, returns the named output list and sets the state
  expect_output(expect_equal(po_1$train(list(1)), list(output = 1)), "Training debug.basic")
  expect_equal(po_1$state, list(input = 1))
  expect_true(po_1$is_trained)
  # $train() must be given a list of inputs, not a bare object
  expect_error(po_1$train(tsk("iris")), regexp = "type 'list'")
})


test_that("PipeOp - simple tests with PipeOpScale", {
  # A newly built PipeOpScale has the expected class, carries a ParamSet,
  # and is not trained yet.
  op_scale = PipeOpScale$new()
  expect_class(op_scale, "PipeOpScale")
  expect_class(op_scale$param_set, "ParamSet")
  expect_false(op_scale$is_trained)
})

# Snapshot tests of the print() output of several PipeOps.
# NOTE: testthat stores the code of each expect_snapshot() call next to its output,
# so editing these expressions (even cosmetically) invalidates the stored snapshots.
test_that("PipeOp printer", {
  # simple PipeOp
  expect_snapshot(print(PipeOpNOP$new()))
  # PipeOpDebugMulti (presumably defined in helper_pipeops.R) with two different
  # constructor configurations -- TODO confirm meaning of the arguments
  expect_snapshot(print(PipeOpDebugMulti$new(3, 4)))
  expect_snapshot(print(PipeOpDebugMulti$new(100, 0)))
  # PipeOpBranch constructed with three named options
  expect_snapshot(print(PipeOpBranch$new(c("odin", "dva", "tri"))))
  # PipeOp wrapping a Learner
  expect_snapshot(print(PipeOpLearner$new(mlr_learners$get("classif.debug"))))
})

test_that("Prevent creation of PipeOps with no channels", {
  # One-row channel table that accepts any type during train and predict.
  any_channel = function(name) data.table(name = name, train = "*", predict = "*")

  # one input and one output channel: construction succeeds
  expect_class(PipeOp$new("id", input = any_channel("input"), output = any_channel("output")), "PipeOp")

  # an empty input table is rejected
  expect_error(
    PipeOp$new("id", input = any_channel("input")[FALSE], output = any_channel("output")),
    "input.*at least 1 row"
  )

  # an empty output table is rejected
  expect_error(
    PipeOp$new("id", input = any_channel("input"), output = any_channel("output")[FALSE]),
    "output.*at least 1 row"
  )
})

test_that("Errors occur for inputs", {
  # The bare PipeOp base class has no train/predict implementation, so both
  # steps must fail with an "abstract" error; $param_set cannot be replaced.
  op = PipeOp$new("id",
    input = data.table(name = "input", train = "*", predict = "*"),
    output = data.table(name = "output", train = "*", predict = "*")
  )
  expect_error(train_pipeop(op, list(tsk("iris"))), "abstract")

  # give the PipeOp a non-NULL state so the predict call gets past the
  # "not trained" check
  op$state = list(NULL)
  expect_error(predict_pipeop(op, list(tsk("iris"))), "abstract")

  expect_error({
    op$param_set = ParamSet$new()
  }, "read-only")
})

test_that("Errors during training set $state to NULL", {
  # A failing $train() must discard any state that was present before the call.
  op = PipeOp$new("id",
    input = data.table(name = "input", train = "*", predict = "*"),
    output = data.table(name = "output", train = "*", predict = "*")
  )
  expect_null(op$state)

  op$state = list("not_null")  # pretend a previous training left some state behind
  expect_error(op$train(list(tsk("iris"))), regexp = "abstract")
  expect_null(op$state)
})

test_that("Informative error and warning messages", {
  # Warnings and errors raised inside a PipeOp are expected to end with a suffix
  # naming the PipeOp and the method ($train() or $predict()) they occurred in.
  gr = as_graph(lrn("classif.debug"))

  gr$param_set$values$classif.debug.warning_train = 1
  gr$param_set$values$classif.debug.warning_predict = 1

  # We want to 'expect' exactly one warning: expect_warning() consumes a single
  # warning, and the surrounding expect_no_warning() fails if any further
  # warning escapes.
  expect_no_warning(expect_warning(gr$train(tsk("iris")), "This happened in PipeOp classif.debug's \\$train\\(\\)$"))

  # when the caller suppresses warnings, nothing may leak out
  expect_no_warning(suppressWarnings(gr$train(tsk("iris"))))

  expect_no_warning(expect_warning(gr$predict(tsk("iris")), "This happened in PipeOp classif.debug's \\$predict\\(\\)$"))

  expect_no_warning(suppressWarnings(gr$predict(tsk("iris"))))

  gr$param_set$values$classif.debug.warning_train = 0
  gr$param_set$values$classif.debug.warning_predict = 0

  gr$param_set$values$classif.debug.error_train = 1
  expect_error(gr$train(tsk("iris")), "This happened in PipeOp classif.debug's \\$train\\(\\)$")

  gr$param_set$values$classif.debug.error_train = 0
  gr$param_set$values$classif.debug.error_predict = 1
  # Need to first train the Graph for predict to work
  gr$train(tsk("iris"))
  expect_error(gr$predict(tsk("iris")), "This happened in PipeOp classif.debug's \\$predict\\(\\)$")

  # A PipeOp whose .train()/.predict() raise a warning and muffle it themselves:
  # the PipeOp's $train()/$predict() must not surface that muffled warning.
  potest = R6::R6Class("potest", inherit = PipeOp,
    private = list(
      .train = function(input) {
        self$state = list()
        suppressWarnings(warning("test"))
        list(1)
      },
      .predict = function(input) {
        suppressWarnings(warning("test"))
        list(1)
      }
    )
  )$new(id = "potest",
    input = data.table(name = "input", train = "*", predict = "*"),
    output = data.table(name = "output", train = "*", predict = "*")
  )

  expect_no_warning(potest$train(list(1)))
  expect_no_warning(potest$predict(list(1)))
})

test_that("properties", {
  # Constructs a minimal PipeOp with one input and one output channel and the
  # given `properties`.
  f = function(properties) {
    PipeOp$new(
      id = "potest",
      input = data.table(name = "input", train = "*", predict = "*"),
      output = data.table(name = "output", train = "*", predict = "*"),
      properties = properties
    )
  }

  # an unknown property value is rejected at construction time
  expect_error(f("abc"))
  po1 = f("validation")
  expect_equal(po1$properties, "validation")
})

test_that("PipeOp - auto-train untrained PipeOps during predict that have input type NULL", {
  # Minimal PipeOp whose input channels all have train type "NULL"; calling
  # $predict() on an untrained instance is expected to train it first.
  # .train() stores the number of inputs as its state; .train() and .predict()
  # both print a message and pass on the first input.
  PipeOpTestAutotrain = R6Class("PipeOpTestAutotrain",
    inherit = PipeOp,
    public = list(
      # innum: number of input channels to create
      initialize = function(innum = 1, id = "test_autotrain", param_set = ps()) {
        super$initialize(id = id, param_set = param_set,
          input = data.table(name = rep_suffix("input", innum), train = "NULL", predict = "*"),
          output = data.table(name = "output", train = "*", predict = "*")
        )
      }),
    private = list(
      .train = function(inputs) {
        catf("Training %s", self$id)
        self$state = list(length(inputs))
        inputs[1L]
      },
      .predict = function(inputs) {
        catf("Predicting %s", self$id)
        inputs[1L]
      }
    )
  )

  # Normal input
  op = PipeOpTestAutotrain$new()
  predict_out = op$predict(list("test"))
  expect_equal(op$state, list(1))
  expect_equal(predict_out, list(output = "test"))

  # single Multiplicity input
  op$state = NULL  # reset PipeOp
  predict_out = op$predict(list(Multiplicity("test", "test")))
  expect_equal(op$state, Multiplicity(list(1), list(1)))
  expect_equal(predict_out, list(output = Multiplicity("test", "test")))

  # nested Multiplicity input
  op$state = NULL  # reset PipeOp
  predict_out = op$predict(list(Multiplicity("test", Multiplicity("test", "test"))))
  expect_equal(op$state, Multiplicity(list(1), Multiplicity(list(1), list(1))))
  expect_equal(predict_out, list(output = Multiplicity("test", Multiplicity("test", "test"))))

  # Tests with a PipeOp that has multiple input channels
  op = PipeOpTestAutotrain$new(innum = 2)

  # Normal input
  predict_out = op$predict(list("test", "test"))
  expect_equal(op$state, list(2))
  expect_equal(predict_out, list(output = "test"))

  # single Multiplicity input
  op$state = NULL  # reset PipeOp
  predict_out = op$predict(list(Multiplicity("test", "test"), Multiplicity("test", "test")))
  expect_equal(op$state, Multiplicity(list(2), list(2)))
  expect_equal(predict_out, list(output = Multiplicity("test", "test")))

  # nested Multiplicity input
  op$state = NULL  # reset PipeOp
  predict_out = op$predict(
    list(Multiplicity("test", Multiplicity("test", "test")), Multiplicity("test", Multiplicity("test", "test")))
  )
  expect_equal(op$state, Multiplicity(list(2), Multiplicity(list(2), list(2))))
  expect_equal(predict_out, list(output = Multiplicity("test", Multiplicity("test", "test"))))

  # Simple test that pseudo-Multiplicity-aware PipeOp (having "[NULL]" as input type) works
  op = PipeOpTestAutotrain$new(innum = 1)
  op$input$train = "[NULL]"
  op$input$predict = "[*]"
  op$output$train = "[*]"
  op$output$predict = "[*]"
  predict_out = op$predict(list(Multiplicity("test", "test")))
  expect_equal(op$state, list(1))
  expect_equal(predict_out, list(output = Multiplicity("test", "test")))

  # Test with a real PipeOp: an untrained PipeOpClassifAvg gets trained by $predict()
  op = po("classifavg")
  expect_false(op$is_trained)
  task = tsk("iris")
  # Get a PredictionClassif object
  learner = lrn("classif.featureless")
  learner$train(task)
  predict_out = learner$predict(task)
  expect_no_error(op$predict(list(predict_out)))
  expect_true(op$is_trained)
})
