浏览代码

训练模型逻辑修改

wangchangsheng 2 年之前
父节点
当前提交
f3332aa987

+ 26 - 3
gyee-sample-impala/src/main/java/com/gyee/impala/controller/diagnose/TrainDataModeController.java

@@ -1,9 +1,15 @@
 package com.gyee.impala.controller.diagnose;
 
 
-import org.springframework.web.bind.annotation.CrossOrigin;
-import org.springframework.web.bind.annotation.RequestMapping;
-import org.springframework.web.bind.annotation.RestController;
+import com.alibaba.fastjson.JSONObject;
+import com.gyee.impala.common.result.JsonResult;
+import com.gyee.impala.common.result.ResultCode;
+import com.gyee.impala.service.master.diagnose.TrainDataModeService;
+import org.apache.kudu.client.ListTablesResponse;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.web.bind.annotation.*;
+
+import java.util.List;
 
 /**
  * 数据源方式训练模型
@@ -12,4 +18,21 @@ import org.springframework.web.bind.annotation.RestController;
 @RestController
 @RequestMapping("/api/traindatamode")
 public class TrainDataModeController {
+
+
+    @Autowired
+    TrainDataModeService trainDataModeService;
+
+    @GetMapping("/getListTables")
+    public JSONObject getListTables(){
+
+        try {
+            List<ListTablesResponse.TableInfo> tables = trainDataModeService.getListTables();
+
+            return JsonResult.successData(ResultCode.SUCCESS,tables);
+        } catch (Exception e) {
+            return JsonResult.error(ResultCode.ERROR);
+        }
+
+    }
 }

+ 28 - 7
gyee-sample-impala/src/main/java/com/gyee/impala/controller/diagnose/TrainFileModeController.java

@@ -5,6 +5,7 @@ import com.alibaba.fastjson.JSONObject;
 import com.gyee.impala.common.result.JsonResult;
 import com.gyee.impala.common.result.ResultCode;
 import com.gyee.impala.model.custom.TokenUser;
+import com.gyee.impala.model.master.diagnose.Diagnosetrainhistory;
 import com.gyee.impala.model.master.diagnose.TrainInfo;
 import com.gyee.impala.service.custom.SftpFileService;
 import com.gyee.impala.service.custom.ShiroService;
@@ -86,8 +87,6 @@ public class TrainFileModeController {
 
 
     private void execute() {
-
-
         trainFileModeService.exec(name1, forecastLabel1, inputLabel1, host1, file1);
     }
 
@@ -120,7 +119,7 @@ public class TrainFileModeController {
 
 
     /**
-     * 训练最有结果
+     * 训练结果
      * @param history
      * @return
      */
@@ -128,7 +127,6 @@ public class TrainFileModeController {
     public JSONObject putDiagnosetrainhistory(@RequestBody  String history){
         try {
             trainFileModeService.putDiagnosetrainhistory(history);
-
             return JsonResult.success(ResultCode.SUCCESS);
         } catch (Exception e) {
             return JsonResult.error(ResultCode.ERROR);
@@ -137,14 +135,37 @@ public class TrainFileModeController {
 
 
     /**
-     * 训练最有结果
+     * 获取最终结果
      * @param history
      * @return
      */
-    @PostMapping("/getHistory")
+    @GetMapping("/getHistory")
     public JSONObject getDiagnosetrainhistory(String history){
         try {
-            return JsonResult.success(ResultCode.SUCCESS);
+            Diagnosetrainhistory d = trainFileModeService.getDiagnosetrainhistory(history);
+            return JsonResult.successData(ResultCode.SUCCESS, d);
+        } catch (Exception e) {
+            return JsonResult.error(ResultCode.ERROR);
+        }
+    }
+
+
+    /**
+     * 编辑最终
+     * @param history
+     * @return
+     */
+    @PostMapping("/editHistory")
+    public JSONObject editDiagnosetrainhistory(String history){
+        try {
+
+            int code = trainFileModeService.editDiagnosetrainhistory(history);
+            if(code >0){
+                return JsonResult.success(ResultCode.SUCCESS);
+            }else {
+                return JsonResult.error(ResultCode.ERROR);
+            }
+
         } catch (Exception e) {
             return JsonResult.error(ResultCode.ERROR);
         }

+ 1 - 1
gyee-sample-impala/src/main/java/com/gyee/impala/model/master/diagnose/TrainInfo.java

@@ -6,5 +6,5 @@ import lombok.Data;
 public class TrainInfo {
     private String log;
     private String time;
-    private  boolean complete;
+    private String complete;
 }

+ 1 - 1
gyee-sample-impala/src/main/java/com/gyee/impala/service/custom/SftpFileService.java

@@ -110,7 +110,7 @@ public class SftpFileService {
      */
     public void uploadFile(String fileName, InputStream in,String host) {
         if (JudeSystem.isWindows()){
-            ChannelSftp sftp = config.getSftpSocket(host,"root","gyee2021");
+            ChannelSftp sftp = config.getSftpSocket(host,"gyee","gyee2021");
             try {
                 boolean dirs = this.createDirs(config.getPath(), sftp);
                 if (!dirs) {

+ 23 - 0
gyee-sample-impala/src/main/java/com/gyee/impala/service/impl/master/diagnose/DiagnosetrainhistoryServiceImpl.java

@@ -58,6 +58,29 @@ public class DiagnosetrainhistoryServiceImpl extends ServiceImpl<Diagnosetrainhi
         }
     }
 
+    @Override
+    public Diagnosetrainhistory getHistoryByModel(String model) {
+        QueryWrapper<Diagnosetrainhistory> wrapper = new QueryWrapper<>();
+        wrapper.eq("model",model);
+        try{
+            return baseMapper.selectList(wrapper).get(0);
+        } catch (Exception e){
+            log.error(e.getMessage());
+            throw new CustomException(ResultCode.ERROR_DATA);
+        }
+    }
+
+    @Override
+    public int editDiagnosetrainhistory(Diagnosetrainhistory history) {
+        try{
+            return baseMapper.updateById(history);
+        } catch (Exception e){
+            log.error(e.getMessage());
+            throw new CustomException(ResultCode.ERROR_DATA);
+        }
+    }
+
+
     /**
      * 由于mybatis-plus存储的中文乱码
      * 采用原生写法

+ 14 - 0
gyee-sample-impala/src/main/java/com/gyee/impala/service/master/diagnose/DiagnosetrainhistoryService.java

@@ -27,4 +27,18 @@ public interface DiagnosetrainhistoryService extends IService<Diagnosetrainhisto
      * @param id
      */
     void deleteById(String id);
+
+    /**
+     * 通过mode获取模型
+     * @param model
+     */
+    Diagnosetrainhistory getHistoryByModel(String model);
+
+
+    /**
+     * 编辑模型
+     * @param history
+     */
+    int editDiagnosetrainhistory(Diagnosetrainhistory history);
+
 }

+ 23 - 0
gyee-sample-impala/src/main/java/com/gyee/impala/service/master/diagnose/TrainDataModeService.java

@@ -0,0 +1,23 @@
+package com.gyee.impala.service.master.diagnose;
+
+import com.gyee.impala.common.config.datasource.KuduDataSourceConfig;
+import org.apache.kudu.client.ListTablesResponse;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+
+import java.util.List;
+
+@Service
+public class TrainDataModeService {
+
+    @Autowired
+    private KuduDataSourceConfig kuduConfig;
+
+
+    public List<ListTablesResponse.TableInfo> getListTables() throws Exception {
+
+        List<ListTablesResponse.TableInfo> ld = kuduConfig.kuduClient.getTablesList().getTableInfosList();
+
+        return ld;
+    }
+}

+ 34 - 29
gyee-sample-impala/src/main/java/com/gyee/impala/service/master/diagnose/TrainFileModeService.java

@@ -4,12 +4,14 @@ package com.gyee.impala.service.master.diagnose;
 import com.alibaba.fastjson.JSONObject;
 import com.gyee.impala.common.config.GyeeConfig;
 import com.gyee.impala.common.config.jsch.JSchConfig;
+import com.gyee.impala.common.util.JudeSystem;
 import com.gyee.impala.model.master.diagnose.Diagnosetrainhistory;
 import com.gyee.impala.model.master.diagnose.TrainInfo;
 import com.gyee.impala.service.custom.SftpFileService;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
 import org.springframework.web.multipart.MultipartFile;
+
 import java.io.BufferedReader;
 import java.io.InputStreamReader;
 import java.util.*;
@@ -60,20 +62,23 @@ public class TrainFileModeService {
             isComplete = false;
             //获取上传文件的文件名
             String type = file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf("."));
-            ;
             /**上传文件**/
             fileService.uploadFile(file.getOriginalFilename(), file.getInputStream(), "10.155.32.14");
 
+
             //文件路径
             String filePath = config.getPath() + file.getOriginalFilename();
 
             System.out.println(new Date() + "开始执行脚本...");
             Process p;
             String cmdPath = gyeeConfig.getDiagnosePath();
+            String inst = JudeSystem.isWindows() ? "cmd" : "/bin/sh";
+            String c = JudeSystem.isWindows() ? "/c" : "-c";
 //            String[] cmd = {"/bin/sh", "-c", "python " + cmdPath + name + ".py " + filePath};
-            String cmd = "cmd /c python " + cmdPath + name + ".py " + filePath;
+            String[] cmd = {inst, c, "python " + cmdPath + name + ".py " + filePath};
 
-//            System.out.println(cmd[0] + " " + cmd[1] + " " + cmd[2]);
+            Thread.sleep(3000);
+            System.out.println(cmd[0] + " " + cmd[1] + " " + cmd[2]);
             p = Runtime.getRuntime().exec(cmd);
             BufferedReader bri = new BufferedReader(new InputStreamReader(p.getInputStream()));
             BufferedReader bre = new BufferedReader(new InputStreamReader(p.getErrorStream()));
@@ -84,34 +89,16 @@ public class TrainFileModeService {
                 }
                 if (se != null) {
                     System.err.println(se);
+                    TrainInfo t = JSONObject.parseObject(se, TrainInfo.class);
+                    produce(t);
                 }
             }
-
-
             p.waitFor();
 
-            for (int l = 0; l < 20; l++) {
-                TrainInfo t = new TrainInfo();
-                t.setLog("请求第" + l + "次");
-                t.setTime("" + (20 - l));
-                t.setComplete(isComplete);
-                produce(t);
-                Thread.sleep(1000);
-            }
-
         } catch (Exception e) {
             e.printStackTrace();
         } finally {
             isComplete = true;
-            try {
-                TrainInfo t1 = new TrainInfo();
-                t1.setLog("请求第20次");
-                t1.setTime("0");
-                t1.setComplete(isComplete);
-                produce(t1);
-            } catch (Exception e) {
-                e.printStackTrace();
-            }
         }
 
         System.out.println(new Date() + "脚本执行结束...");
@@ -121,7 +108,9 @@ public class TrainFileModeService {
     // 控制台信息
     public void produce(TrainInfo trainInfo) throws Exception {
         // put 控制台信息到队列中
-        trainInfo.setComplete(trainInfo.isComplete() ? trainInfo.isComplete() : isComplete);
+//        trainInfo.setComplete(trainInfo.getComplete() ? trainInfo.isComplete() : isComplete);
+
+        System.out.println("add log =["+trainInfo.getLog()+"]  time = ["+trainInfo.getTime()+"] complete = ["+isComplete+"]");
         infoQueue.put(trainInfo);
     }
 
@@ -134,23 +123,39 @@ public class TrainFileModeService {
         for (int i = 0; i < sise; i++) {
             TrainInfo info = infoQueue.take();
             list.add(info);
-            if (info.isComplete()) {
-                break;
-            }
         }
         return list;
     }
 
 
-    public void putDiagnosetrainhistory(String history) throws Exception{
+    /**
+     * 保存训练结果
+     *
+     * @param history
+     * @throws Exception
+     */
+    public void putDiagnosetrainhistory(String history) throws Exception {
 
         Diagnosetrainhistory d = JSONObject.parseObject(history, Diagnosetrainhistory.class);
         d.setEnable(true);
         diagnosetrainhistoryService.insertItem(d);
-        historyQueue.put(d);
+        Diagnosetrainhistory dbhistory = diagnosetrainhistoryService.getHistoryByModel(d.getModel());
+        historyQueue.put(dbhistory);
+    }
+
+    public Diagnosetrainhistory getDiagnosetrainhistory(String id) throws Exception {
+        Diagnosetrainhistory history = historyQueue.take();
+        return history;
+    }
 
 
+    public int editDiagnosetrainhistory(String history) {
+        Diagnosetrainhistory update = JSONObject.parseObject(history, Diagnosetrainhistory.class);
 
+        int code  = diagnosetrainhistoryService.editDiagnosetrainhistory(update);
+
+        return  code;
     }
 
+
 }