TrainFaultDiagnoseController.java 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. package com.gyee.impala.controller.diagnose;
  2. import com.alibaba.fastjson.JSONObject;
  3. import com.gyee.impala.common.result.JsonResult;
  4. import com.gyee.impala.common.result.ResultCode;
  5. import com.gyee.impala.common.util.DateUtil;
  6. import com.gyee.impala.model.custom.diagnose.DataInfo;
  7. import com.gyee.impala.model.custom.diagnose.ExecuteInfo;
  8. import com.gyee.impala.model.master.Casefault;
  9. import com.gyee.impala.model.master.diagnose.Diagnosepoint;
  10. import com.gyee.impala.model.master.diagnose.Diagnosetrainhistory;
  11. import com.gyee.impala.model.master.diagnose.TrainInfo;
  12. import com.gyee.impala.service.custom.diagnose.CmdFaultDiagnoseService;
  13. import com.gyee.impala.service.custom.diagnose.DataDiagnoseService;
  14. import com.gyee.impala.service.custom.diagnose.TrainDataModeService;
  15. import com.gyee.impala.service.master.CasefaultService;
  16. import lombok.extern.slf4j.Slf4j;
  17. import org.apache.kudu.client.ListTablesResponse;
  18. import org.springframework.beans.factory.annotation.Autowired;
  19. import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
  20. import org.springframework.web.bind.annotation.*;
  21. import org.springframework.web.multipart.MultipartFile;
  22. import javax.annotation.Resource;
  23. import java.util.*;
  24. import java.util.stream.Collectors;
  25. /**
  26. * 故障诊断模型训练
  27. */
  28. @Slf4j
  29. @CrossOrigin
  30. @RestController
  31. @RequestMapping("/api/diagnose")
  32. public class TrainFaultDiagnoseController {
  33. @Autowired
  34. private DataDiagnoseService dataService;
  35. @Autowired
  36. private CasefaultService casefaultService;
  37. @Autowired
  38. TrainDataModeService trainDataModeService;
  39. /**
  40. * 线程池
  41. */
  42. @Resource
  43. private ThreadPoolTaskExecutor taskExecutor;
  44. @Autowired
  45. private CmdFaultDiagnoseService trainFileModeService;
  46. private static final Object locker = new Object();
  47. private String name1;
  48. private String forecastLabel1;
  49. private String[] inputLabel1;
  50. private String host1;
  51. private MultipartFile file1;
  52. private ExecuteInfo executeInfo;
  53. private String fileName;
  54. /**
  55. * 文件模式训练接口
  56. * @param name
  57. * @param forecastLabel
  58. * @param inputLabel
  59. * @param host
  60. * @param file
  61. * @return
  62. */
  63. @PostMapping("/trainfile")
  64. @ResponseBody
  65. public JSONObject getTrainfile(String name, String forecastLabel, String[] inputLabel, String host, MultipartFile file) {
  66. if (!trainFileModeService.isComplete()) {
  67. return JsonResult.error(4000, "命令正在执行...");
  68. }
  69. if (file.isEmpty()) {
  70. return JsonResult.error(ResultCode.ERROR_FILE_NO);
  71. }
  72. try {
  73. synchronized (locker) {
  74. name1 = name;
  75. forecastLabel1 = forecastLabel;
  76. inputLabel1 = inputLabel;
  77. host1 = host;
  78. file1 = file;
  79. taskExecutor.submit(this::execute);
  80. }
  81. return JsonResult.success(ResultCode.SUCCESS);
  82. } catch (Exception e) {
  83. return JsonResult.error(ResultCode.ERROR_DATA_FILE);
  84. }
  85. }
  86. /**
  87. * 调用执行脚本
  88. */
  89. private void execute() {
  90. trainFileModeService.exec(name1, forecastLabel1, inputLabel1, host1, file1);
  91. }
  92. /** 在线训练 **/
  93. /**
  94. * 查询数据库的表
  95. * @return
  96. */
  97. @GetMapping("/tables")
  98. public JSONObject getListTables(){
  99. List<Map<String, String>> list = new ArrayList<>();
  100. try {
  101. List<ListTablesResponse.TableInfo> tables = trainDataModeService.getListTables();
  102. tables.stream().filter(a -> a.getTableName().equals("impala::gyee_sample_kudu.casefault")).forEach(obj -> {
  103. Map<String, String> map = new HashMap<>();
  104. String name = obj.getTableName().substring(obj.getTableName().lastIndexOf(".") + 1);
  105. map.put("tableId", obj.getTableId());
  106. map.put("tableName", name);
  107. list.add(map);
  108. });
  109. } catch (Exception e) { e.getMessage(); }
  110. return JsonResult.successData(ResultCode.SUCCESS,list);
  111. }
  112. /**
  113. * 查询数据库表的列
  114. * @param table
  115. * @return
  116. */
  117. @GetMapping("/columns")
  118. public JSONObject getColumns(String table){
  119. Object columns = null;
  120. try {
  121. columns = trainDataModeService.getColumns(table);
  122. } catch (Exception e) { e.getMessage(); }
  123. return JsonResult.successData(ResultCode.SUCCESS, columns);
  124. }
  125. /** 查询样本数据 **/
  126. @GetMapping("/data")
  127. public JSONObject getData(String sql){
  128. List<Casefault> list = casefaultService.executeSql(sql);
  129. return JsonResult.successData(ResultCode.SUCCESS, list);
  130. }
  131. /** 开始训练 查询 golden 所有原始数据
  132. * flag ture: 所有数据
  133. * flag false: 前10条数据
  134. * **/
  135. @PostMapping("/pointdata")
  136. public JSONObject getPointData(@RequestBody JSONObject json){
  137. if (json == null)
  138. return JsonResult.error(ResultCode.PARAM_IS_BLANK);
  139. boolean flag = json.getBooleanValue("flag");
  140. List<Diagnosepoint> points = JSONObject.parseArray(json.getJSONArray("points").toString(), Diagnosepoint.class);
  141. List<Casefault> faults = JSONObject.parseArray(json.getJSONArray("faults").toString(), Casefault.class);
  142. /** 组装数据 删除添加的故障类型**/
  143. dataService.formatUniformcode(points.stream().filter(a -> !a.getUniformcode().equals("faulttype")).collect(Collectors.toList()));
  144. executeInfo = new ExecuteInfo();
  145. Calendar cal = Calendar.getInstance();
  146. DataInfo[] dataInfos = new DataInfo[faults.size()];
  147. for (int i = 0; i < faults.size(); i++){
  148. DataInfo data = new DataInfo();
  149. data.setId(Long.valueOf(faults.get(i).getId()));
  150. data.setStationId(faults.get(i).getStationen());
  151. data.setThingId(faults.get(i).getWindturbineid());
  152. data.setModelId(faults.get(i).getModel());
  153. data.setTag(faults.get(i).getFaultcode());
  154. data.setFaultTime(faults.get(i).getStarttime());
  155. cal.setTime(DateUtil.parseStrtoDate(faults.get(i).getStarttime(), DateUtil.YYYY_MM_DD_HH_MM_SS));
  156. cal.add(Calendar.MINUTE, -10);
  157. data.setStartTs(cal.getTimeInMillis() + "");
  158. cal.add(Calendar.MINUTE, 10);
  159. data.setEndTs(cal.getTimeInMillis() + "");
  160. dataInfos[i] = data;
  161. }
  162. executeInfo.setDataInfos(dataInfos);
  163. if (flag){
  164. if (!trainFileModeService.isComplete()) {
  165. return JsonResult.error(4000, "已有正在训练的模型...");
  166. }
  167. synchronized (locker) {
  168. taskExecutor.submit(this::execute2);
  169. }
  170. return JsonResult.success(ResultCode.SUCCESS);
  171. }else {
  172. Map<String, Object> mp = dataService.getFormData(executeInfo);
  173. return JsonResult.successData(ResultCode.SUCCESS, mp);
  174. }
  175. }
  176. /**
  177. * 调用执行脚本
  178. */
  179. private void execute2() {
  180. fileName = dataService.getFormDataAll(executeInfo);
  181. trainFileModeService.exec();
  182. }
  183. /**
  184. * py 获取在线训练数据
  185. *
  186. * @return
  187. */
  188. @GetMapping("/traindata")
  189. public JSONObject getData() {
  190. Map<String, Object> map = new HashMap<>();
  191. map.put("info", this.executeInfo);
  192. map.put("filename", fileName);
  193. return JsonResult.successData(ResultCode.SUCCESS, map);
  194. }
  195. /** 在线训练 **/
  196. /**
  197. * 生产控制台信息
  198. * @param trainInfo
  199. * @return
  200. */
  201. @PostMapping("/addtrainInfo")
  202. public JSONObject addProducer(@RequestBody String trainInfo) {
  203. try {
  204. TrainInfo t = JSONObject.parseObject(trainInfo, TrainInfo.class);
  205. trainFileModeService.produce(t);
  206. return JsonResult.success(ResultCode.SUCCESS);
  207. } catch (Exception e) {
  208. return JsonResult.error(ResultCode.ERROR);
  209. }
  210. }
  211. /**
  212. * 消费控制台信息
  213. * @return
  214. */
  215. @GetMapping("/gettrainInfo")
  216. @ResponseBody
  217. public JSONObject getConsume() {
  218. try {
  219. List<TrainInfo> list = trainFileModeService.consume();
  220. System.out.println(JsonResult.successData(ResultCode.SUCCESS, list));
  221. return JsonResult.successData(ResultCode.SUCCESS, list);
  222. } catch (Exception e) {
  223. return JsonResult.error(ResultCode.ERROR);
  224. }
  225. }
  226. /**
  227. * 添加训练结果
  228. *
  229. * @param history
  230. * @return
  231. */
  232. @PostMapping("/putHistory")
  233. public JSONObject putDiagnosetrainhistory(@RequestBody String history) {
  234. try {
  235. log.info("训练模型结果:" + history);
  236. trainFileModeService.putDiagnosetrainhistory(history);
  237. return JsonResult.success(ResultCode.SUCCESS);
  238. } catch (Exception e) {
  239. log.error("请求错误", e);
  240. return JsonResult.error(ResultCode.ERROR);
  241. }
  242. }
  243. /**
  244. * 获取当前训练结果
  245. * @return
  246. */
  247. @GetMapping("/getHistory")
  248. public JSONObject getHistory() {
  249. try {
  250. Diagnosetrainhistory d = trainFileModeService.consumeHistory();
  251. return JsonResult.successData(ResultCode.SUCCESS, d);
  252. } catch (Exception e) {
  253. log.error("请求错误", e);
  254. return JsonResult.error(ResultCode.ERROR);
  255. }
  256. }
  257. /**
  258. * 预测评估
  259. *
  260. * @param jsonObject
  261. * @return
  262. */
  263. @PostMapping("/forecasts")
  264. public JSONObject forecasts(@RequestBody JSONObject jsonObject) {
  265. try {
  266. log.warn("预估请求数据:" + jsonObject.toJSONString());
  267. String resultvalue = trainFileModeService.forecasts(jsonObject);
  268. return JsonResult.successData(ResultCode.SUCCESS, resultvalue);
  269. } catch (Exception e) {
  270. return JsonResult.error(ResultCode.ERROR);
  271. }
  272. }
  273. }