Forráskód Böngészése

模型预估接口整理

wangchangsheng 2 éve
szülő
commit
b53c47be11

+ 27 - 10
gyee-sample-impala/src/main/java/com/gyee/impala/controller/diagnose/TrainFileModeController.java

@@ -31,10 +31,6 @@ public class TrainFileModeController {
 
     protected Log log = LogFactory.getLog(getClass());
 
-
-    @Autowired
-    private ShiroService shiroService;
-
     /**
      * 线程池
      */
@@ -55,14 +51,19 @@ public class TrainFileModeController {
 
 
     /**
-     *
+     * 文件模式训练接口
+     * @param request
+     * @param name
+     * @param forecastLabel
+     * @param inputLabel
+     * @param host
+     * @param file
+     * @return
      */
     @PostMapping("/trainfile")
     @ResponseBody
     public JSONObject getTrainfile(HttpServletRequest request, String name, String forecastLabel, String[] inputLabel, String host, MultipartFile file) {
 
-//        String token = request.getHeader("token");
-//        TokenUser user = shiroService.findToken(token);
         if (!trainFileModeService.isComplete()) {
             return JsonResult.error(4000, "命令正在执行...");
         }
@@ -88,10 +89,18 @@ public class TrainFileModeController {
     }
 
 
+    /**
+     * 调用执行脚本
+     */
     private void execute() {
         trainFileModeService.exec(name1, forecastLabel1, inputLabel1, host1, file1);
     }
 
+    /**
+     * 生产控制台信息
+     * @param trainInfo
+     * @return
+     */
     @PostMapping("/addtrainInfo")
     public JSONObject addProducer(@RequestBody String trainInfo) {
 
@@ -107,6 +116,10 @@ public class TrainFileModeController {
     }
 
 
+    /**
+     * 消费控制台信息
+     * @return
+     */
     @GetMapping("/gettrainInfo")
     @ResponseBody
     public JSONObject getConsume() {
@@ -121,7 +134,7 @@ public class TrainFileModeController {
 
 
     /**
-     * 训练结果
+     * 添加训练结果
      *
      * @param history
      * @return
@@ -139,6 +152,10 @@ public class TrainFileModeController {
     }
 
 
+    /**
+     * 获取当前训练结果
+     * @return
+     */
     @GetMapping("/getHistory")
     public JSONObject getHistory() {
         try {
@@ -175,7 +192,7 @@ public class TrainFileModeController {
     @GetMapping("/getHistoryList")
     public JSONObject getDiagnosetrainhistoryList() {
         try {
-            List<Diagnosetrainhistory> list = trainFileModeService.editDiagnosetrainhistoryList();
+            List<Diagnosetrainhistory> list = trainFileModeService.getDiagnosetrainhistoryList();
             return JsonResult.successData(ResultCode.SUCCESS, list);
         } catch (Exception e) {
             return JsonResult.error(ResultCode.ERROR);
@@ -191,7 +208,7 @@ public class TrainFileModeController {
     @PostMapping("/forecasts")
     public JSONObject forecasts(@RequestBody JSONObject jsonObject) {
         try {
-            log.warn("预估请求数据:"+jsonObject.toJSONString());
+            log.warn("预估请求数据:" + jsonObject.toJSONString());
             String resultvalue = trainFileModeService.forecasts(jsonObject);
             return JsonResult.successData(ResultCode.SUCCESS, resultvalue);
         } catch (Exception e) {

+ 62 - 48
gyee-sample-impala/src/main/java/com/gyee/impala/service/master/diagnose/TrainFileModeService.java

@@ -52,13 +52,14 @@ public class TrainFileModeService {
         return isComplete;
     }
 
+    //控制太输出信息队列
     BlockingQueue<TrainInfo> infoQueue = new LinkedBlockingQueue<TrainInfo>();
+    //零时保存模型队列
     BlockingQueue<Diagnosetrainhistory> historyQueue = new LinkedBlockingQueue<Diagnosetrainhistory>();
 
-
-
     public void exec(String name, String forecastLabel, String[] inputLabel, String host, MultipartFile file) {
 
+        //判断当前是否有模型在训练
         if (!isComplete) {
             return;
         }
@@ -68,21 +69,20 @@ public class TrainFileModeService {
             String type = file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf("."));
             /**上传文件**/
             fileService.uploadFile(file.getOriginalFilename(), file.getInputStream(), "10.155.32.14");
-
-
             //文件路径
             String filePath = config.getPath() + file.getOriginalFilename();
 
-            System.out.println(DateUtil.getCurrentDate() + "开始执行脚本"+(isComplete)+"...");
+            System.out.println(DateUtil.getCurrentDate() + "开始执行脚本" + (isComplete) + "...");
             Process p;
+            //组装调用python脚本命令
             String cmdPath = gyeeConfig.getDiagnosePath();
             String inst = JudeSystem.isWindows() ? "cmd" : "/bin/sh";
             String c = JudeSystem.isWindows() ? "/c" : "-c";
-//            String[] cmd = {"/bin/sh", "-c", "python " + cmdPath + name + ".py " + filePath};
             String[] cmd = {inst, c, "python " + cmdPath + name + ".py " + filePath};
-
             System.out.println(cmd[0] + " " + cmd[1] + " " + cmd[2]);
+            //执行命令
             p = Runtime.getRuntime().exec(cmd);
+            //监听脚本输出信息
             BufferedReader bri = new BufferedReader(new InputStreamReader(p.getInputStream()));
             BufferedReader bre = new BufferedReader(new InputStreamReader(p.getErrorStream()));
             String si = null, se = null;
@@ -102,22 +102,27 @@ public class TrainFileModeService {
             isComplete = true;
         }
 
-
-        System.out.println(DateUtil.getCurrentDate()+ "脚本执行结束"+(isComplete)+"...");
+        System.out.println(DateUtil.getCurrentDate() + "脚本执行结束" + (isComplete) + "...");
     }
 
 
-    // 控制台信息
+    /**
+     * 控制台信息
+     * @param trainInfo
+     * @throws Exception
+     */
     public void produce(TrainInfo trainInfo) throws Exception {
         // put 控制台信息到队列中
-//        trainInfo.setComplete(trainInfo.getComplete() ? trainInfo.isComplete() : isComplete);
-
-        System.out.println("add log =["+trainInfo.getLog()+"]  time = ["+trainInfo.getTime()+"] complete = ["+trainInfo.getComplete()+"]");
+        System.out.println("add log =[" + trainInfo.getLog() + "]  time = [" + trainInfo.getTime() + "] complete = [" + trainInfo.getComplete() + "]");
         infoQueue.put(trainInfo);
     }
 
 
-    // 输出控制台信息
+    /**
+     * 输出控制台信息
+     * @return list
+     * @throws Exception
+     */
     public List<TrainInfo> consume() throws Exception {
         // take方法取出一条记录
         List<TrainInfo> list = new ArrayList<>();
@@ -136,24 +141,28 @@ public class TrainFileModeService {
      * @param history
      * @throws Exception
      */
-    public void putDiagnosetrainhistory(String  history) throws Exception {
-
+    public void putDiagnosetrainhistory(String history) throws Exception {
 
         String body = JSONObject.toJSONString(history);
-        Object historys =  JSONObject.parse(body);
-
+        Object historys = JSONObject.parse(body);
         Diagnosetrainhistory d = JSONObject.parseObject(historys.toString(), Diagnosetrainhistory.class);
         d.setEnable(true);
         diagnosetrainhistoryService.insertItem(d);
         Diagnosetrainhistory dbhistory = diagnosetrainhistoryService.getHistoryByModel(d.getModel());
         historyQueue.put(dbhistory);
+
     }
 
+    /**
+     * 训练完成获取实时模型结果
+     *
+     * @return
+     * @throws Exception
+     */
     public Diagnosetrainhistory consumeHistory() throws Exception {
-
         Diagnosetrainhistory d = null;
         try {
-            if(historyQueue.size()>0){
+            if (historyQueue.size() > 0) {
                 d = historyQueue.take();
             }
         } catch (InterruptedException e) {
@@ -162,71 +171,76 @@ public class TrainFileModeService {
         return d;
     }
 
-    public String  getDiagnosetrainhistory1() throws Exception {
-
-        return "000000";
-    }
-
 
-    public void editDiagnosetrainhistory(String history) throws Exception{
+    /**
+     * 编辑模型
+     *
+     * @param history
+     * @throws Exception
+     */
+    public void editDiagnosetrainhistory(String history) throws Exception {
         Diagnosetrainhistory update = JSONObject.parseObject(history, Diagnosetrainhistory.class);
         diagnosetrainhistoryService.editDiagnosetrainhistory(update);
     }
 
 
-    public List<Diagnosetrainhistory>  editDiagnosetrainhistoryList() {
-
-        List<Diagnosetrainhistory>  list  = diagnosetrainhistoryService.getList();
-
-        return  list;
+    /**
+     * 查询历史模型
+     *
+     * @return
+     */
+    public List<Diagnosetrainhistory> getDiagnosetrainhistoryList() {
+        List<Diagnosetrainhistory> list = diagnosetrainhistoryService.getList();
+        return list;
     }
 
 
-    public String forecasts(JSONObject jsonObject){
-
+    /**
+     * 预估
+     *
+     * @param jsonObject
+     * @return
+     */
+    public String forecasts(JSONObject jsonObject) {
 
         String resultvalue = null;
         try {
-            String  name = jsonObject.get("name").toString();
-
-//            String predict = jsonObject.get("predict").toString();
-            String filename  = System.currentTimeMillis()+".json";
+            //解析请求数据
+            String name = jsonObject.get("name").toString();
+            String filename = System.currentTimeMillis() + ".json";
             //文件路径
             String filePath = config.getPath() + filename;
             fileService.uploadFile(filename, FileUtil.convertStringToInputStream(jsonObject.toJSONString()), "10.155.32.14");
             resultvalue = "";
+            //组装调用脚本命令
             String cmdPath = gyeeConfig.getDiagnosePath();
             String inst = JudeSystem.isWindows() ? "cmd" : "/bin/sh";
             String c = JudeSystem.isWindows() ? "/c" : "-c";
             String[] cmd = {inst, c, "python " + cmdPath + name + ".py " + filePath};
-
             Process p;
             System.out.println(cmd[0] + " " + cmd[1] + " " + cmd[2]);
+            //执行脚本
             p = Runtime.getRuntime().exec(cmd);
+            //脚本输出信息
             BufferedReader bri = new BufferedReader(new InputStreamReader(p.getInputStream()));
             BufferedReader bre = new BufferedReader(new InputStreamReader(p.getErrorStream()));
             String si = null, se = null;
             while ((si = bri.readLine()) != null || (se = bre.readLine()) != null) {
                 if (si != null) {
-                    System.out.println("预估返回信息:"+si);
-                    if (si.startsWith("resultvalue:")){
-                        resultvalue = si.replaceAll("resultvalue:","");
+                    System.out.println("预估返回信息:" + si);
+                    if (si.startsWith("resultvalue:")) {
+                        resultvalue = si.replaceAll("resultvalue:", "");
                     }
                 }
                 if (se != null) {
-                    System.err.println("预估返回错误信息:"+se);
+                    System.err.println("预估返回错误信息:" + se);
                 }
             }
             p.waitFor();
         } catch (Exception e) {
             e.printStackTrace();
         }
-        return  resultvalue;
-
+        return resultvalue.trim();
 
     }
-
-
-
-
 }