Parcourir la source

Merge branch 'master' of http://124.70.43.205:3000/xieshengjie/sis-background

# Conflicts:
#	gyee-sample-impala/src/main/java/com/gyee/impala/service/master/diagnose/TrainDataModeService.java
chenminghua il y a 3 ans
Parent
commit
805310cd06

+ 69 - 27
gyee-sample-impala/src/main/java/com/gyee/impala/controller/diagnose/TrainFileModeController.java

@@ -4,12 +4,12 @@ package com.gyee.impala.controller.diagnose;
 import com.alibaba.fastjson.JSONObject;
 import com.gyee.impala.common.result.JsonResult;
 import com.gyee.impala.common.result.ResultCode;
-import com.gyee.impala.model.custom.TokenUser;
 import com.gyee.impala.model.master.diagnose.Diagnosetrainhistory;
 import com.gyee.impala.model.master.diagnose.TrainInfo;
-import com.gyee.impala.service.custom.SftpFileService;
 import com.gyee.impala.service.custom.ShiroService;
 import com.gyee.impala.service.master.diagnose.TrainFileModeService;
+import org.apache.ibatis.logging.Log;
+import org.apache.ibatis.logging.LogFactory;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
 import org.springframework.web.bind.annotation.*;
@@ -28,9 +28,8 @@ import java.util.List;
 @RequestMapping("/api/filemode")
 public class TrainFileModeController {
 
-    @Autowired
-    private ShiroService shiroService;
 
+    protected Log log = LogFactory.getLog(getClass());
 
     /**
      * 线程池
@@ -52,14 +51,19 @@ public class TrainFileModeController {
 
 
     /**
-     *
+     * 文件模式训练接口
+     * @param request
+     * @param name
+     * @param forecastLabel
+     * @param inputLabel
+     * @param host
+     * @param file
+     * @return
      */
     @PostMapping("/trainfile")
     @ResponseBody
     public JSONObject getTrainfile(HttpServletRequest request, String name, String forecastLabel, String[] inputLabel, String host, MultipartFile file) {
 
-//        String token = request.getHeader("token");
-//        TokenUser user = shiroService.findToken(token);
         if (!trainFileModeService.isComplete()) {
             return JsonResult.error(4000, "命令正在执行...");
         }
@@ -71,10 +75,10 @@ public class TrainFileModeController {
         try {
             synchronized (locker) {
                 name1 = name;
-                forecastLabel1= forecastLabel;
-                inputLabel1= inputLabel;
-                host1= host;
-                file1= file;
+                forecastLabel1 = forecastLabel;
+                inputLabel1 = inputLabel;
+                host1 = host;
+                file1 = file;
                 taskExecutor.submit(this::execute);
             }
 
@@ -85,11 +89,18 @@ public class TrainFileModeController {
     }
 
 
-
+    /**
+     * 调用执行脚本
+     */
     private void execute() {
         trainFileModeService.exec(name1, forecastLabel1, inputLabel1, host1, file1);
     }
 
+    /**
+     * 生产控制台信息
+     * @param trainInfo
+     * @return
+     */
     @PostMapping("/addtrainInfo")
     public JSONObject addProducer(@RequestBody String trainInfo) {
 
@@ -105,6 +116,10 @@ public class TrainFileModeController {
     }
 
 
+    /**
+     * 消费控制台信息
+     * @return
+     */
     @GetMapping("/gettrainInfo")
     @ResponseBody
     public JSONObject getConsume() {
@@ -119,32 +134,35 @@ public class TrainFileModeController {
 
 
     /**
-     * 训练结果
+     * 添加训练结果
+     *
      * @param history
      * @return
      */
     @PostMapping("/putHistory")
-    public JSONObject putDiagnosetrainhistory(@RequestBody  String history){
+    public JSONObject putDiagnosetrainhistory(@RequestBody String history) {
         try {
+            log.warn(history);
             trainFileModeService.putDiagnosetrainhistory(history);
             return JsonResult.success(ResultCode.SUCCESS);
         } catch (Exception e) {
+            log.error("请求错误", e);
             return JsonResult.error(ResultCode.ERROR);
         }
     }
 
 
     /**
-     * 获取最终结果
-     * @param history
+     * 获取当前训练结果
      * @return
      */
     @GetMapping("/getHistory")
-    public JSONObject getDiagnosetrainhistory(String history){
+    public JSONObject getHistory() {
         try {
-            Diagnosetrainhistory d = trainFileModeService.getDiagnosetrainhistory(history);
+            Diagnosetrainhistory d = trainFileModeService.consumeHistory();
             return JsonResult.successData(ResultCode.SUCCESS, d);
         } catch (Exception e) {
+            log.error("请求错误", e);
             return JsonResult.error(ResultCode.ERROR);
         }
     }
@@ -152,28 +170,52 @@ public class TrainFileModeController {
 
     /**
      * 编辑最终
+     *
      * @param history
      * @return
      */
     @PostMapping("/editHistory")
-    public JSONObject editDiagnosetrainhistory(String history){
+    public JSONObject editDiagnosetrainhistory(@RequestBody String history) {
         try {
 
-            int code = trainFileModeService.editDiagnosetrainhistory(history);
-            if(code >0){
-                return JsonResult.success(ResultCode.SUCCESS);
-            }else {
-                return JsonResult.error(ResultCode.ERROR);
-            }
-
+            trainFileModeService.editDiagnosetrainhistory(history);
+            return JsonResult.success(ResultCode.SUCCESS);
         } catch (Exception e) {
-            return JsonResult.error(ResultCode.ERROR);
+            return JsonResult.error(ResultCode.ERROR_DATA);
         }
     }
 
 
+    /**
+     * @return
+     */
+    @GetMapping("/getHistoryList")
+    public JSONObject getDiagnosetrainhistoryList() {
+        try {
+            List<Diagnosetrainhistory> list = trainFileModeService.getDiagnosetrainhistoryList();
+            return JsonResult.successData(ResultCode.SUCCESS, list);
+        } catch (Exception e) {
+            return JsonResult.error(ResultCode.ERROR);
+        }
+    }
 
+    /**
+     * 预测评估
+     *
+     * @param jsonObject
+     * @return
+     */
+    @PostMapping("/forecasts")
+    public JSONObject forecasts(@RequestBody JSONObject jsonObject) {
+        try {
+            log.warn("预估请求数据:" + jsonObject.toJSONString());
+            String resultvalue = trainFileModeService.forecasts(jsonObject);
+            return JsonResult.successData(ResultCode.SUCCESS, resultvalue);
+        } catch (Exception e) {
+            return JsonResult.error(ResultCode.ERROR);
+        }
 
+    }
 
 
 }

+ 2 - 0
gyee-sample-impala/src/main/java/com/gyee/impala/model/master/diagnose/TrainInfo.java

@@ -7,4 +7,6 @@ public class TrainInfo {
     private String log;
     private String time;
     private String complete;
+    private String modelname;
+
 }

+ 19 - 3
gyee-sample-impala/src/main/java/com/gyee/impala/service/impl/master/diagnose/DiagnosetrainhistoryServiceImpl.java

@@ -6,7 +6,7 @@ import com.gyee.impala.common.base.ExcludeQueryWrapper;
 import com.gyee.impala.common.config.datasource.KuduDataSourceConfig;
 import com.gyee.impala.common.exception.CustomException;
 import com.gyee.impala.common.result.ResultCode;
-import com.gyee.impala.common.util.DateUtil;
+import com.gyee.impala.common.spring.InitialRunner;
 import com.gyee.impala.common.util.SnowFlakeUtil;
 import com.gyee.impala.mapper.master.diagnose.DiagnosetrainhistoryMapper;
 import com.gyee.impala.model.master.diagnose.AlgorithmType;
@@ -24,6 +24,9 @@ public class DiagnosetrainhistoryServiceImpl extends ServiceImpl<Diagnosetrainhi
     @Autowired
     private KuduDataSourceConfig kuduConfig;
 
+    @Autowired
+    private InitialRunner initialRunner;
+
     @Override
     public List<Diagnosetrainhistory> getAll( AlgorithmType type) {
         ExcludeQueryWrapper<Diagnosetrainhistory> wrapper = new ExcludeQueryWrapper<>();
@@ -71,9 +74,22 @@ public class DiagnosetrainhistoryServiceImpl extends ServiceImpl<Diagnosetrainhi
     }
 
     @Override
-    public int editDiagnosetrainhistory(Diagnosetrainhistory history) {
+    public void editDiagnosetrainhistory(Diagnosetrainhistory history) {
+        try{
+            baseMapper.updateById(history);
+            initialRunner.cacheKnowCategory();
+        } catch (Exception e){
+            log.error(e.getMessage());
+            throw new CustomException(ResultCode.ERROR_DATA);
+        }
+    }
+
+    @Override
+    public List<Diagnosetrainhistory> getList() {
+        QueryWrapper<Diagnosetrainhistory> wrapper = new QueryWrapper<>();
+//        wrapper.eq("model",model);
         try{
-            return baseMapper.updateById(history);
+            return baseMapper.selectList(wrapper);
         } catch (Exception e){
             log.error(e.getMessage());
             throw new CustomException(ResultCode.ERROR_DATA);

+ 8 - 1
gyee-sample-impala/src/main/java/com/gyee/impala/service/master/diagnose/DiagnosetrainhistoryService.java

@@ -39,6 +39,13 @@ public interface DiagnosetrainhistoryService extends IService<Diagnosetrainhisto
      * 编辑模型
      * @param history
      */
-    int editDiagnosetrainhistory(Diagnosetrainhistory history);
+    void editDiagnosetrainhistory(Diagnosetrainhistory history);
+
+
+    /**
+     * 获取训练历史结果
+     * @return
+     */
+    List<Diagnosetrainhistory> getList();
 
 }

+ 0 - 2
gyee-sample-impala/src/main/java/com/gyee/impala/service/master/diagnose/TrainDataModeService.java

@@ -17,9 +17,7 @@ public class TrainDataModeService {
 
 
     public List<ListTablesResponse.TableInfo> getListTables() throws Exception {
-
         List<ListTablesResponse.TableInfo> tables = kuduConfig.kuduClient.getTablesList().getTableInfosList();
-
         return tables;
     }
 

+ 109 - 24
gyee-sample-impala/src/main/java/com/gyee/impala/service/master/diagnose/TrainFileModeService.java

@@ -4,10 +4,14 @@ package com.gyee.impala.service.master.diagnose;
 import com.alibaba.fastjson.JSONObject;
 import com.gyee.impala.common.config.GyeeConfig;
 import com.gyee.impala.common.config.jsch.JSchConfig;
+import com.gyee.impala.common.util.DateUtil;
+import com.gyee.impala.common.util.FileUtil;
 import com.gyee.impala.common.util.JudeSystem;
 import com.gyee.impala.model.master.diagnose.Diagnosetrainhistory;
 import com.gyee.impala.model.master.diagnose.TrainInfo;
 import com.gyee.impala.service.custom.SftpFileService;
+import org.apache.ibatis.logging.Log;
+import org.apache.ibatis.logging.LogFactory;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
 import org.springframework.web.multipart.MultipartFile;
@@ -22,6 +26,7 @@ import java.util.concurrent.LinkedBlockingQueue;
 public class TrainFileModeService {
 
 
+    protected Log log = LogFactory.getLog(getClass());
     @Autowired
     private SftpFileService fileService;
 
@@ -47,14 +52,14 @@ public class TrainFileModeService {
         return isComplete;
     }
 
+    //控制太输出信息队列
     BlockingQueue<TrainInfo> infoQueue = new LinkedBlockingQueue<TrainInfo>();
+    //零时保存模型队列
     BlockingQueue<Diagnosetrainhistory> historyQueue = new LinkedBlockingQueue<Diagnosetrainhistory>();
 
-    Map<String, BlockingQueue<TrainInfo>> mapConsole = new HashMap<>();
-
-
     public void exec(String name, String forecastLabel, String[] inputLabel, String host, MultipartFile file) {
 
+        //判断当前是否有模型在训练
         if (!isComplete) {
             return;
         }
@@ -64,22 +69,20 @@ public class TrainFileModeService {
             String type = file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf("."));
             /**上传文件**/
             fileService.uploadFile(file.getOriginalFilename(), file.getInputStream(), "10.155.32.14");
-
-
             //文件路径
             String filePath = config.getPath() + file.getOriginalFilename();
 
-            System.out.println(new Date() + "开始执行脚本...");
+            System.out.println(DateUtil.getCurrentDate() + "开始执行脚本" + (isComplete) + "...");
             Process p;
+            //组装调用python脚本命令
             String cmdPath = gyeeConfig.getDiagnosePath();
             String inst = JudeSystem.isWindows() ? "cmd" : "/bin/sh";
             String c = JudeSystem.isWindows() ? "/c" : "-c";
-//            String[] cmd = {"/bin/sh", "-c", "python " + cmdPath + name + ".py " + filePath};
             String[] cmd = {inst, c, "python " + cmdPath + name + ".py " + filePath};
-
-            Thread.sleep(3000);
             System.out.println(cmd[0] + " " + cmd[1] + " " + cmd[2]);
+            //执行命令
             p = Runtime.getRuntime().exec(cmd);
+            //监听脚本输出信息
             BufferedReader bri = new BufferedReader(new InputStreamReader(p.getInputStream()));
             BufferedReader bre = new BufferedReader(new InputStreamReader(p.getErrorStream()));
             String si = null, se = null;
@@ -89,8 +92,6 @@ public class TrainFileModeService {
                 }
                 if (se != null) {
                     System.err.println(se);
-                    TrainInfo t = JSONObject.parseObject(se, TrainInfo.class);
-                    produce(t);
                 }
             }
             p.waitFor();
@@ -101,21 +102,27 @@ public class TrainFileModeService {
             isComplete = true;
         }
 
-        System.out.println(new Date() + "脚本执行结束...");
+        System.out.println(DateUtil.getCurrentDate() + "脚本执行结束" + (isComplete) + "...");
     }
 
 
-    // 控制台信息
+    /**
+     * 控制台信息
+     * @param trainInfo
+     * @throws Exception
+     */
     public void produce(TrainInfo trainInfo) throws Exception {
         // put 控制台信息到队列中
-//        trainInfo.setComplete(trainInfo.getComplete() ? trainInfo.isComplete() : isComplete);
-
-        System.out.println("add log =["+trainInfo.getLog()+"]  time = ["+trainInfo.getTime()+"] complete = ["+isComplete+"]");
+        System.out.println("add log =[" + trainInfo.getLog() + "]  time = [" + trainInfo.getTime() + "] complete = [" + trainInfo.getComplete() + "]");
         infoQueue.put(trainInfo);
     }
 
 
-    // 输出控制台信息
+    /**
+     * 输出控制台信息
+     * @return list
+     * @throws Exception
+     */
     public List<TrainInfo> consume() throws Exception {
         // take方法取出一条记录
         List<TrainInfo> list = new ArrayList<>();
@@ -136,26 +143,104 @@ public class TrainFileModeService {
      */
     public void putDiagnosetrainhistory(String history) throws Exception {
 
-        Diagnosetrainhistory d = JSONObject.parseObject(history, Diagnosetrainhistory.class);
+        String body = JSONObject.toJSONString(history);
+        Object historys = JSONObject.parse(body);
+        Diagnosetrainhistory d = JSONObject.parseObject(historys.toString(), Diagnosetrainhistory.class);
         d.setEnable(true);
         diagnosetrainhistoryService.insertItem(d);
         Diagnosetrainhistory dbhistory = diagnosetrainhistoryService.getHistoryByModel(d.getModel());
         historyQueue.put(dbhistory);
+
     }
 
-    public Diagnosetrainhistory getDiagnosetrainhistory(String id) throws Exception {
-        Diagnosetrainhistory history = historyQueue.take();
-        return history;
+    /**
+     * 训练完成获取实时模型结果
+     *
+     * @return
+     * @throws Exception
+     */
+    public Diagnosetrainhistory consumeHistory() throws Exception {
+        Diagnosetrainhistory d = null;
+        try {
+            if (historyQueue.size() > 0) {
+                d = historyQueue.take();
+            }
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+        return d;
     }
 
 
-    public int editDiagnosetrainhistory(String history) {
+    /**
+     * 编辑模型
+     *
+     * @param history
+     * @throws Exception
+     */
+    public void editDiagnosetrainhistory(String history) throws Exception {
         Diagnosetrainhistory update = JSONObject.parseObject(history, Diagnosetrainhistory.class);
+        diagnosetrainhistoryService.editDiagnosetrainhistory(update);
+    }
 
-        int code  = diagnosetrainhistoryService.editDiagnosetrainhistory(update);
 
-        return  code;
+    /**
+     * 查询历史模型
+     *
+     * @return
+     */
+    public List<Diagnosetrainhistory> getDiagnosetrainhistoryList() {
+        List<Diagnosetrainhistory> list = diagnosetrainhistoryService.getList();
+        return list;
     }
 
 
+    /**
+     * 预估
+     *
+     * @param jsonObject
+     * @return
+     */
+    public String forecasts(JSONObject jsonObject) {
+
+        String resultvalue = null;
+        try {
+            //解析请求数据
+            String name = jsonObject.get("name").toString();
+            String filename = System.currentTimeMillis() + ".json";
+            //文件路径
+            String filePath = config.getPath() + filename;
+            fileService.uploadFile(filename, FileUtil.convertStringToInputStream(jsonObject.toJSONString()), "10.155.32.14");
+            resultvalue = "";
+            //组装调用脚本命令
+            String cmdPath = gyeeConfig.getDiagnosePath();
+            String inst = JudeSystem.isWindows() ? "cmd" : "/bin/sh";
+            String c = JudeSystem.isWindows() ? "/c" : "-c";
+            String[] cmd = {inst, c, "python " + cmdPath + name + ".py " + filePath};
+            Process p;
+            System.out.println(cmd[0] + " " + cmd[1] + " " + cmd[2]);
+            //执行脚本
+            p = Runtime.getRuntime().exec(cmd);
+            //脚本输出信息
+            BufferedReader bri = new BufferedReader(new InputStreamReader(p.getInputStream()));
+            BufferedReader bre = new BufferedReader(new InputStreamReader(p.getErrorStream()));
+            String si = null, se = null;
+            while ((si = bri.readLine()) != null || (se = bre.readLine()) != null) {
+                if (si != null) {
+                    System.out.println("预估返回信息:" + si);
+                    if (si.startsWith("resultvalue:")) {
+                        resultvalue = si.replaceAll("resultvalue:", "");
+                    }
+                }
+                if (se != null) {
+                    System.err.println("预估返回错误信息:" + se);
+                }
+            }
+            p.waitFor();
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+        return resultvalue.trim();
+
+    }
 }