浏览代码

保存训练模型记录

wangchangsheng 2 年之前
父节点
当前提交
b08a328735

+ 2 - 2
gyee-sample-impala/src/main/java/com/gyee/impala/controller/diagnose/CoordinateController.java

@@ -174,8 +174,8 @@ public class CoordinateController {
         List<Diagnosetrainhistory> ls = historyService.getAll(AlgorithmType.unsupervised);
         for (Diagnosetrainhistory he : ls) {
             History h = new History();
-            h.setCoordinate(he.getContext());
-            h.setFaultIds(Arrays.asList(he.getFaultids().split(",")));
+//            h.setCoordinate(he.getContext());
+//            h.setFaultIds(Arrays.asList(he.getFaultids().split(",")));
             h.setName(he.getName());
             hs.add(h);
         }

+ 5 - 5
gyee-sample-impala/src/main/java/com/gyee/impala/controller/diagnose/GearboxDiagnosisController.java

@@ -38,7 +38,7 @@ public class GearboxDiagnosisController {
     @GetMapping("/complete/{fname}")
     public JSONObject setCompleteFileName(@PathVariable String fname) {
         Diagnosetrainhistory he = new Diagnosetrainhistory();
-        he.setType(AlgorithmType.gearbox);
+        //he.setType(AlgorithmType.gearbox);
         he.setTime(DateUtil.getTime());
         he.setName(fname);
         historyService.insertItem(he);
@@ -51,10 +51,10 @@ public class GearboxDiagnosisController {
         List<Diagnosetrainhistory> ls = historyService.getAll(AlgorithmType.gearbox);
         for (Diagnosetrainhistory he : ls) {
             History h = new History();
-            h.setCoordinate(he.getContext());
-            if (StringUtils.isNotEmpty(he.getFaultids())) {
-                h.setFaultIds(Arrays.asList(he.getFaultids().split(",")));
-            }
+//            h.setCoordinate(he.getContext());
+//            if (StringUtils.isNotEmpty(he.getFaultids())) {
+//                h.setFaultIds(Arrays.asList(he.getFaultids().split(",")));
+//            }
             h.setName(he.getName());
             hs.add(h);
         }

+ 2 - 2
gyee-sample-impala/src/main/java/com/gyee/impala/controller/diagnose/SupervisedController.java

@@ -195,8 +195,8 @@ public class SupervisedController {
         List<Diagnosetrainhistory> ls = historyService.getAll(AlgorithmType.supervised);
         for (Diagnosetrainhistory he : ls) {
             History h = new History();
-            h.setCoordinate(he.getContext());
-            h.setFaultIds(Arrays.asList(he.getFaultids().split(",")));
+//            h.setCoordinate(he.getContext());
+//            h.setFaultIds(Arrays.asList(he.getFaultids().split(",")));
             h.setName(he.getName());
             hs.add(h);
         }

+ 84 - 4
gyee-sample-impala/src/main/java/com/gyee/impala/controller/diagnose/TrainFileModeController.java

@@ -4,13 +4,18 @@ package com.gyee.impala.controller.diagnose;
 import com.alibaba.fastjson.JSONObject;
 import com.gyee.impala.common.result.JsonResult;
 import com.gyee.impala.common.result.ResultCode;
+import com.gyee.impala.model.custom.TokenUser;
 import com.gyee.impala.model.master.diagnose.TrainInfo;
 import com.gyee.impala.service.custom.SftpFileService;
+import com.gyee.impala.service.custom.ShiroService;
 import com.gyee.impala.service.master.diagnose.TrainFileModeService;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
 import org.springframework.web.bind.annotation.*;
 import org.springframework.web.multipart.MultipartFile;
 
+import javax.annotation.Resource;
+import javax.servlet.http.HttpServletRequest;
 import java.util.List;
 
 
@@ -23,11 +28,26 @@ import java.util.List;
 public class TrainFileModeController {
 
     @Autowired
-    private SftpFileService fileService;
+    private ShiroService shiroService;
+
+
+    /**
+     * 线程池
+     */
+    @Resource
+    private ThreadPoolTaskExecutor taskExecutor;
 
     @Autowired
     private TrainFileModeService trainFileModeService;
 
+    private static final Object locker = new Object();
+
+
+    private String name1;
+    private String forecastLabel1;
+    private String[] inputLabel1;
+    private String host1;
+    private MultipartFile file1;
 
 
     /**
@@ -35,13 +55,28 @@ public class TrainFileModeController {
      */
     @PostMapping("/trainfile")
     @ResponseBody
-    public JSONObject getTrainfile(String name, String forecastLabel, String[] inputLabel, String host, MultipartFile file) {
+    public JSONObject getTrainfile(HttpServletRequest request, String name, String forecastLabel, String[] inputLabel, String host, MultipartFile file) {
+
+//        String token = request.getHeader("token");
+//        TokenUser user = shiroService.findToken(token);
+        if (!trainFileModeService.isComplete()) {
+            return JsonResult.error(4000, "命令正在执行...");
+        }
+
         if (file.isEmpty()) {
             return JsonResult.error(ResultCode.ERROR_FILE_NO);
         }
 
         try {
-            trainFileModeService.exec(name, forecastLabel, inputLabel, host, file);
+            synchronized (locker) {
+                name1 = name;
+                forecastLabel1= forecastLabel;
+                inputLabel1= inputLabel;
+                host1= host;
+                file1= file;
+                taskExecutor.submit(this::execute);
+            }
+
             return JsonResult.success(ResultCode.SUCCESS);
         } catch (Exception e) {
             return JsonResult.error(ResultCode.ERROR_DATA_FILE);
@@ -49,12 +84,20 @@ public class TrainFileModeController {
     }
 
 
+
+    private void execute() {
+
+
+        trainFileModeService.exec(name1, forecastLabel1, inputLabel1, host1, file1);
+    }
+
     @PostMapping("/addtrainInfo")
     public JSONObject addProducer(@RequestBody String trainInfo) {
 
         try {
             TrainInfo t = JSONObject.parseObject(trainInfo, TrainInfo.class);
             trainFileModeService.produce(t);
+
             return JsonResult.success(ResultCode.SUCCESS);
         } catch (Exception e) {
             return JsonResult.error(ResultCode.ERROR);
@@ -63,11 +106,12 @@ public class TrainFileModeController {
     }
 
 
-    @PostMapping("/gettrainInfo")
+    @GetMapping("/gettrainInfo")
     @ResponseBody
     public JSONObject getConsume() {
         try {
             List<TrainInfo> list = trainFileModeService.consume();
+            System.out.println(JsonResult.successData(ResultCode.SUCCESS, list));
             return JsonResult.successData(ResultCode.SUCCESS, list);
         } catch (Exception e) {
             return JsonResult.error(ResultCode.ERROR);
@@ -75,4 +119,40 @@ public class TrainFileModeController {
     }
 
 
+    /**
+     * 训练最有结果
+     * @param history
+     * @return
+     */
+    @PostMapping("/putHistory")
+    public JSONObject putDiagnosetrainhistory(@RequestBody  String history){
+        try {
+            trainFileModeService.putDiagnosetrainhistory(history);
+
+            return JsonResult.success(ResultCode.SUCCESS);
+        } catch (Exception e) {
+            return JsonResult.error(ResultCode.ERROR);
+        }
+    }
+
+
+    /**
+     * 训练最有结果
+     * @param history
+     * @return
+     */
+    @PostMapping("/getHistory")
+    public JSONObject getDiagnosetrainhistory(String history){
+        try {
+            return JsonResult.success(ResultCode.SUCCESS);
+        } catch (Exception e) {
+            return JsonResult.error(ResultCode.ERROR);
+        }
+    }
+
+
+
+
+
+
 }

+ 30 - 7
gyee-sample-impala/src/main/java/com/gyee/impala/model/master/diagnose/Diagnosetrainhistory.java

@@ -16,26 +16,49 @@ import lombok.EqualsAndHashCode;
 public class Diagnosetrainhistory extends Model<Diagnosetrainhistory> {
 
     private String id;
+
+    /**
+     * 模型
+     */
+    private String model;
+
+    /**
+     * 模型编码
+     */
+    private String  code;
+
+    /**
+     * 致信度
+     */
+    private double accuracy;
+
     /**
      * 名称
      */
     private String name;
+
     /**
      * 时间
      */
     private String time;
+
     /**
-     * 算法类型
+     * 模型指标
      */
-    private AlgorithmType type;
+    private String  indicators;
+
     /**
-     * 内容
+     * 测点权重
      */
-    private String context;
+    private String pointweight;
+
     /**
-     * 故障ID
+     * 是否启用
      */
-    private String faultids;
+    private boolean enable;
 
-    private String createtime;
+    /**
+     * 备注
+     */
+    private String remark;
 }

+ 4 - 4
gyee-sample-impala/src/main/java/com/gyee/impala/service/custom/diagnose/CmdService.java

@@ -113,10 +113,10 @@ public class CmdService {
         }
         Diagnosetrainhistory he = new Diagnosetrainhistory();
         he.setName(name);
-        he.setFaultids(Arrays.stream(this.executeInfo.getDataInfos()).map(i -> i.getId() + "").collect(Collectors.joining(",")));
-        he.setContext(context);
-        he.setTime(DateUtil.getTime());
-        he.setType(AlgorithmType.unsupervised);
+//        he.setFaultids(Arrays.stream(this.executeInfo.getDataInfos()).map(i -> i.getId() + "").collect(Collectors.joining(",")));
+//        he.setContext(context);
+//        he.setTime(DateUtil.getTime());
+//        he.setType(AlgorithmType.unsupervised);
         return he;
     }
 }

+ 4 - 4
gyee-sample-impala/src/main/java/com/gyee/impala/service/custom/diagnose/SupervisedCmdService.java

@@ -102,10 +102,10 @@ public class SupervisedCmdService {
         }
         Diagnosetrainhistory he = new Diagnosetrainhistory();
         he.setName(name);
-        he.setFaultids(Arrays.stream(this.executeInfo.getDataInfos()).map(i -> i.getId() + "").collect(Collectors.joining(",")));
-        he.setContext(supervised);
-        he.setTime(DateUtil.getTime());
-        he.setType(AlgorithmType.supervised);
+//        he.setFaultids(Arrays.stream(this.executeInfo.getDataInfos()).map(i -> i.getId() + "").collect(Collectors.joining(",")));
+//        he.setContext(supervised);
+//        he.setTime(DateUtil.getTime());
+//        he.setType(AlgorithmType.supervised);
         return he;
     }
 }

+ 7 - 4
gyee-sample-impala/src/main/java/com/gyee/impala/service/impl/master/diagnose/DiagnosetrainhistoryServiceImpl.java

@@ -72,12 +72,15 @@ public class DiagnosetrainhistoryServiceImpl extends ServiceImpl<Diagnosetrainhi
         // 获取Row对象,设置插入的值
         PartialRow row = insert.getRow();
         row.addObject("id", SnowFlakeUtil.generateId());
+        row.addObject("model", obj.getModel());
+        row.addObject("code", obj.getCode());
+        row.addObject("accuracy", obj.getAccuracy());
         row.addObject("name", obj.getName());
         row.addObject("time", obj.getTime());
-        row.addObject("type", obj.getType());
-        row.addObject("context", obj.getContext());
-        row.addObject("faultids", obj.getFaultids());
-        row.addObject("createtime", DateUtil.getCurrentDate());
+        row.addObject("indicators", obj.getIndicators());
+        row.addObject("pointweight", obj.getPointweight());
+        row.addObject("enable", obj.isEnable());
+        row.addObject("remark", obj.getRemark());
 
         // 先不提交kudu
         kuduSession.apply(insert);

+ 69 - 55
gyee-sample-impala/src/main/java/com/gyee/impala/service/master/diagnose/TrainFileModeService.java

@@ -1,20 +1,18 @@
 package com.gyee.impala.service.master.diagnose;
 
 
+import com.alibaba.fastjson.JSONObject;
 import com.gyee.impala.common.config.GyeeConfig;
 import com.gyee.impala.common.config.jsch.JSchConfig;
+import com.gyee.impala.model.master.diagnose.Diagnosetrainhistory;
 import com.gyee.impala.model.master.diagnose.TrainInfo;
-import com.gyee.impala.model.master.diagnose.TrainParam;
 import com.gyee.impala.service.custom.SftpFileService;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
 import org.springframework.web.multipart.MultipartFile;
-
 import java.io.BufferedReader;
 import java.io.InputStreamReader;
-import java.util.ArrayList;
-import java.util.Date;
-import java.util.List;
+import java.util.*;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.LinkedBlockingQueue;
 
@@ -28,6 +26,9 @@ public class TrainFileModeService {
     @Autowired
     private JSchConfig config;
 
+    @Autowired
+    private DiagnosetrainhistoryService diagnosetrainhistoryService;
+
     /**
      * 保存脚本的位置
      */
@@ -40,38 +41,40 @@ public class TrainFileModeService {
      */
     private boolean isComplete = true;
 
+    public boolean isComplete() {
+        return isComplete;
+    }
 
-    BlockingQueue<TrainInfo> console = new LinkedBlockingQueue<TrainInfo>();
+    BlockingQueue<TrainInfo> infoQueue = new LinkedBlockingQueue<TrainInfo>();
+    BlockingQueue<Diagnosetrainhistory> historyQueue = new LinkedBlockingQueue<Diagnosetrainhistory>();
 
+    Map<String, BlockingQueue<TrainInfo>> mapConsole = new HashMap<>();
 
-    public synchronized void exec(String name, String forecastLabel, String[] inputLabel, String host, MultipartFile file) throws Exception {
+
+    public void exec(String name, String forecastLabel, String[] inputLabel, String host, MultipartFile file) {
 
         if (!isComplete) {
             return;
         }
-        isComplete = false;
-        //获取上传文件的文件名
-        String type = file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf("."));
-        ;
-        /**上传文件**/
-        //fileService.uploadFile(file.getOriginalFilename(), file.getInputStream(),"10.155.32.14");
-
-        //文件路径
-        String filePath = config.getPath() + file.getOriginalFilename();
-
-        System.out.println(new Date() + "开始执行脚本...");
-        Process p;
-        String cmdPath = gyeeConfig.getDiagnosePath();
-
-
         try {
-            String[] cmd = {"/bin/sh", "-c", "python " + cmdPath + name + ".py " + filePath};
-//            String cmd = "cmd /c python " + cmdPath + name + ".py " + filePath;
-
-            System.out.println(cmd[0] + " " + cmd[1] + " " + cmd[2]);
+            isComplete = false;
+            //获取上传文件的文件名
+            String type = file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf("."));
+            ;
+            /**上传文件**/
+            fileService.uploadFile(file.getOriginalFilename(), file.getInputStream(), "10.155.32.14");
+
+            //文件路径
+            String filePath = config.getPath() + file.getOriginalFilename();
+
+            System.out.println(new Date() + "开始执行脚本...");
+            Process p;
+            String cmdPath = gyeeConfig.getDiagnosePath();
+//            String[] cmd = {"/bin/sh", "-c", "python " + cmdPath + name + ".py " + filePath};
+            String cmd = "cmd /c python " + cmdPath + name + ".py " + filePath;
+
+//            System.out.println(cmd[0] + " " + cmd[1] + " " + cmd[2]);
             p = Runtime.getRuntime().exec(cmd);
-
-
             BufferedReader bri = new BufferedReader(new InputStreamReader(p.getInputStream()));
             BufferedReader bre = new BufferedReader(new InputStreamReader(p.getErrorStream()));
             String si = null, se = null;
@@ -80,17 +83,37 @@ public class TrainFileModeService {
                     System.out.println(si);
                 }
                 if (se != null) {
-                    System.out.println(se);
+                    System.err.println(se);
                 }
             }
 
-            Thread.sleep(10000);
+
             p.waitFor();
+
+            for (int l = 0; l < 20; l++) {
+                TrainInfo t = new TrainInfo();
+                t.setLog("请求第" + l + "次");
+                t.setTime("" + (20 - l));
+                t.setComplete(isComplete);
+                produce(t);
+                Thread.sleep(1000);
+            }
+
         } catch (Exception e) {
             e.printStackTrace();
         } finally {
             isComplete = true;
+            try {
+                TrainInfo t1 = new TrainInfo();
+                t1.setLog("请求第20次");
+                t1.setTime("0");
+                t1.setComplete(isComplete);
+                produce(t1);
+            } catch (Exception e) {
+                e.printStackTrace();
+            }
         }
+
         System.out.println(new Date() + "脚本执行结束...");
     }
 
@@ -98,8 +121,8 @@ public class TrainFileModeService {
     // 控制台信息
     public void produce(TrainInfo trainInfo) throws Exception {
         // put 控制台信息到队列中
-        trainInfo.setComplete(isComplete);
-        console.put(trainInfo);
+        trainInfo.setComplete(trainInfo.isComplete() ? trainInfo.isComplete() : isComplete);
+        infoQueue.put(trainInfo);
     }
 
 
@@ -107,36 +130,27 @@ public class TrainFileModeService {
     public List<TrainInfo> consume() throws Exception {
         // take方法取出一条记录
         List<TrainInfo> list = new ArrayList<>();
-        int sise = console.size();
+        int sise = infoQueue.size();
         for (int i = 0; i < sise; i++) {
-            list.add(console.take());
+            TrainInfo info = infoQueue.take();
+            list.add(info);
+            if (info.isComplete()) {
+                break;
+            }
         }
         return list;
     }
 
-  /*  // 定义苹果消费者
-    class Consumer implements Runnable {
-        private String instance;
-        private TrainInfo basket;
 
-        public Consumer(String instance, TrainInfo basket) {
-            this.instance = instance;
-            this.basket = basket;
-        }
+    public void putDiagnosetrainhistory(String history) throws Exception{
+
+        Diagnosetrainhistory d = JSONObject.parseObject(history, Diagnosetrainhistory.class);
+        d.setEnable(true);
+        diagnosetrainhistoryService.insertItem(d);
+        historyQueue.put(d);
+
 
-        public void run() {
-            try {
-                while (true) {
-                    // 消费苹果
-                    System.out.println("消费者准备消费苹果:" + instance);
-                    System.out.println(basket.getLog());
-                    System.out.println("!消费者消费苹果完毕:" + instance);
 
-                }
-            } catch (Exception ex) {
-                System.out.println("Consumer Interrupted");
-            }
-        }
     }
-*/
+
 }

+ 9 - 5
gyee-sample-impala/src/main/resources/mapper/master/DiagnosetrainhistoryMapper.xml

@@ -5,17 +5,21 @@
     <!-- 通用查询映射结果 -->
     <resultMap id="BaseResultMap" type="com.gyee.impala.model.master.diagnose.Diagnosetrainhistory">
         <id column="id" property="id" />
+        <result column="model" property="model" />
+        <result column="code" property="code" />
+        <result column="accuracy" property="accuracy" />
         <result column="name" property="name" />
         <result column="time" property="time" />
-        <result column="type" property="type" />
-        <result column="context" property="context" />
-        <result column="faultids" property="faultids" />
-        <result column="createtime" property="createtime" />
+        <result column="indicators" property="indicators" />
+        <result column="pointweight" property="pointweight" />
+        <result column="enable" property="enable" />
+        <result column="remark" property="remark" />
+
     </resultMap>
 
     <!-- 通用查询结果列 -->
     <sql id="Base_Column_List">
-        id,name,time,type,context,faultids,createtime
+        id,name,time,type
     </sql>
 
 </mapper>