package com.gyee.impala.controller.diagnose; import com.alibaba.fastjson.JSONObject; import com.gyee.impala.common.result.JsonResult; import com.gyee.impala.common.result.ResultCode; import com.gyee.impala.model.master.diagnose.Diagnosetrainhistory; import com.gyee.impala.model.master.diagnose.TrainInfo; import com.gyee.impala.service.custom.ShiroService; import com.gyee.impala.service.master.diagnose.TrainFileModeService; import org.apache.ibatis.logging.Log; import org.apache.ibatis.logging.LogFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; import javax.annotation.Resource; import javax.servlet.http.HttpServletRequest; import java.util.List; /** * 文件方式训练模型 */ @CrossOrigin @RestController @RequestMapping("/api/filemode") public class TrainFileModeController { protected Log log = LogFactory.getLog(getClass()); /** * 线程池 */ @Resource private ThreadPoolTaskExecutor taskExecutor; @Autowired private TrainFileModeService trainFileModeService; private static final Object locker = new Object(); private String name1; private String forecastLabel1; private String[] inputLabel1; private String host1; private MultipartFile file1; /** * 文件模式训练接口 * @param request * @param name * @param forecastLabel * @param inputLabel * @param host * @param file * @return */ @PostMapping("/trainfile") @ResponseBody public JSONObject getTrainfile(HttpServletRequest request, String name, String forecastLabel, String[] inputLabel, String host, MultipartFile file) { if (!trainFileModeService.isComplete()) { return JsonResult.error(4000, "命令正在执行..."); } if (file.isEmpty()) { return JsonResult.error(ResultCode.ERROR_FILE_NO); } try { synchronized (locker) { name1 = name; forecastLabel1 = forecastLabel; inputLabel1 = inputLabel; host1 = host; file1 = file; taskExecutor.submit(this::execute); } return JsonResult.success(ResultCode.SUCCESS); } catch (Exception e) { return JsonResult.error(ResultCode.ERROR_DATA_FILE); } } /** * 调用执行脚本 */ private void execute() { trainFileModeService.exec(name1, forecastLabel1, inputLabel1, host1, file1); } /** * 生产控制台信息 * @param trainInfo * @return */ @PostMapping("/addtrainInfo") public JSONObject addProducer(@RequestBody String trainInfo) { try { TrainInfo t = JSONObject.parseObject(trainInfo, TrainInfo.class); trainFileModeService.produce(t); return JsonResult.success(ResultCode.SUCCESS); } catch (Exception e) { return JsonResult.error(ResultCode.ERROR); } } /** * 消费控制台信息 * @return */ @GetMapping("/gettrainInfo") @ResponseBody public JSONObject getConsume() { try { List<TrainInfo> list = trainFileModeService.consume(); System.out.println(JsonResult.successData(ResultCode.SUCCESS, list)); return JsonResult.successData(ResultCode.SUCCESS, list); } catch (Exception e) { return JsonResult.error(ResultCode.ERROR); } } /** * 添加训练结果 * * @param history * @return */ @PostMapping("/putHistory") public JSONObject putDiagnosetrainhistory(@RequestBody String history) { try { log.warn(history); trainFileModeService.putDiagnosetrainhistory(history); return JsonResult.success(ResultCode.SUCCESS); } catch (Exception e) { log.error("请求错误", e); return JsonResult.error(ResultCode.ERROR); } } /** * 获取当前训练结果 * @return */ @GetMapping("/getHistory") public JSONObject getHistory() { try { Diagnosetrainhistory d = trainFileModeService.consumeHistory(); return JsonResult.successData(ResultCode.SUCCESS, d); } catch (Exception e) { log.error("请求错误", e); return JsonResult.error(ResultCode.ERROR); } } /** * 编辑最终 * * @param history * @return */ @PostMapping("/editHistory") public JSONObject editDiagnosetrainhistory(@RequestBody String history) { try { trainFileModeService.editDiagnosetrainhistory(history); return JsonResult.success(ResultCode.SUCCESS); } catch (Exception e) { return JsonResult.error(ResultCode.ERROR_DATA); } } /** * @return */ @GetMapping("/getHistoryList") public JSONObject getDiagnosetrainhistoryList() { try { List<Diagnosetrainhistory> list = trainFileModeService.getDiagnosetrainhistoryList(); return JsonResult.successData(ResultCode.SUCCESS, list); } catch (Exception e) { return JsonResult.error(ResultCode.ERROR); } } /** * 预测评估 * * @param jsonObject * @return */ @PostMapping("/forecasts") public JSONObject forecasts(@RequestBody JSONObject jsonObject) { try { log.warn("预估请求数据:" + jsonObject.toJSONString()); String resultvalue = trainFileModeService.forecasts(jsonObject); return JsonResult.successData(ResultCode.SUCCESS, resultvalue); } catch (Exception e) { return JsonResult.error(ResultCode.ERROR); } } }